Compare commits

..

12 commits

2 changed files with 221 additions and 205 deletions

View file

@ -5,20 +5,18 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
request = []
request = ["surf/h1-client-rustls"]
convert_from_notion = []
[dependencies]
async-trait = "0.1.68"
async-std = "1.12.0"
# async-trait = "0.1.68"
base64 = "0.21.0"
chrono = "0.4.31"
futures-core = "0.3.28"
lazy_static = "1.4.0"
log = "0.4.20"
regex = "1.7.1"
reqwest = { version = "0.11.14", features = ["json"] }
serde = { version = "^1.0", features = ["derive"], default-features = false }
serde_json = { version = "^1.0", features = ["raw_value"], default-features = false }
[dev-dependencies]
tokio = { version = "1.28.1", features = ["macros"] }
surf = { version = "2.3.2", default-features = false }

View file

@ -1,15 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, NaiveTime, Utc};
use lazy_static::lazy_static;
use regex::Regex;
#[cfg(feature = "request")]
use reqwest::header::{HeaderMap, HeaderValue};
use serde::de::Error as SerdeError;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::json;
use serde_json::Value;
use surf::http::StatusCode;
use futures_core::future::BoxFuture;
@ -22,18 +20,16 @@ lazy_static! {
const NOTION_VERSION: &str = "2022-06-28";
pub type Result<T> = std::result::Result<T, Error>;
pub type Callback = dyn Fn(
&mut reqwest::RequestBuilder,
) -> BoxFuture<'_, std::result::Result<reqwest::Response, reqwest::Error>>
pub type Callback = dyn Fn(surf::RequestBuilder) -> BoxFuture<'static, std::result::Result<surf::Response, surf::Error>>
+ 'static
+ Send
+ Sync;
#[derive(Debug)]
pub enum Error {
Http(reqwest::Error, Option<Value>),
Http(StatusCode, surf::Response),
Surf(surf::Error),
Deserialization(serde_json::Error, Option<Value>),
Header(reqwest::header::InvalidHeaderValue),
ChronoParse(chrono::ParseError),
UnexpectedType,
}
@ -44,15 +40,9 @@ impl std::fmt::Display for Error {
}
}
impl From<reqwest::Error> for Error {
fn from(error: reqwest::Error) -> Self {
Error::Http(error, None)
}
}
impl From<reqwest::header::InvalidHeaderValue> for Error {
fn from(error: reqwest::header::InvalidHeaderValue) -> Self {
Error::Header(error)
impl From<surf::Error> for Error {
fn from(error: surf::Error) -> Self {
Error::Surf(error)
}
}
@ -68,37 +58,31 @@ impl From<chrono::ParseError> for Error {
}
}
async fn try_to_parse_response<T: std::fmt::Debug + for<'de> serde::Deserialize<'de>>(
response: reqwest::Response,
) -> Result<T> {
let text = response.text().await?;
// async fn try_to_parse_response<T: std::fmt::Debug + for<'de> serde::Deserialize<'de>>(
// mut response: surf::Response,
// ) -> Result<T> {
// let text = response.body_string().await?;
match serde_json::from_str::<T>(&text) {
Ok(value) => Ok(value),
Err(error) => match serde_json::from_str::<Value>(&text) {
Ok(body) => Err(Error::Deserialization(error, Some(body))),
_ => Err(Error::Deserialization(error, Some(Value::String(text)))),
},
}
}
// match serde_json::from_str::<T>(&text) {
// Ok(value) => Ok(value),
// Err(error) => match serde_json::from_str::<Value>(&text) {
// Ok(body) => Err(Error::Deserialization(error, Some(body))),
// _ => Err(Error::Deserialization(error, Some(Value::String(text)))),
// },
// }
// }
#[cfg(feature = "request")]
fn get_http_client(notion_api_key: &str) -> reqwest::Client {
let mut headers = HeaderMap::new();
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {notion_api_key}"))
.expect("bearer token to be parsed into a header"),
);
headers.insert(
"Notion-Version",
HeaderValue::from_str(NOTION_VERSION).expect("notion version to be parsed into a header"),
);
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
reqwest::ClientBuilder::new()
.default_headers(headers)
.build()
fn get_http_client(notion_api_key: &str) -> surf::Client {
log::trace!("Readying HTTP Client");
surf::Config::new()
.add_header("Authorization", format!("Bearer {notion_api_key}"))
.expect("to add Authorization header")
.add_header("Notion-Version", NOTION_VERSION)
.expect("to add Notion-Version header")
.add_header("Content-Type", "application/json")
.expect("to add Content-Type header")
.try_into()
.expect("to build a valid client out of notion_api_key")
}
@ -119,7 +103,7 @@ pub struct SearchOptions<'a> {
#[derive(Default)]
pub struct ClientBuilder {
api_key: Option<String>,
custom_request: Option<Arc<Callback>>,
custom_request: Option<Box<Callback>>,
}
impl ClientBuilder {
@ -132,13 +116,13 @@ impl ClientBuilder {
pub fn custom_request<F>(mut self, callback: F) -> Self
where
for<'c> F: Fn(
&'c mut reqwest::RequestBuilder,
) -> BoxFuture<'c, std::result::Result<reqwest::Response, reqwest::Error>>
surf::RequestBuilder,
) -> BoxFuture<'static, std::result::Result<surf::Response, surf::Error>>
+ 'static
+ Send
+ Sync,
{
self.custom_request = Some(Arc::new(callback));
self.custom_request = Some(Box::new(callback));
self
}
@ -147,76 +131,71 @@ impl ClientBuilder {
pub fn build(self) -> Client {
let notion_api_key = self.api_key.expect("api_key to be set");
let request_handler = self.custom_request.unwrap_or(Arc::new(
|request_builder: &mut reqwest::RequestBuilder| {
Box::pin(async move {
let request = request_builder
.try_clone()
.expect("non-stream body request clone to succeed");
request.send().await
})
},
));
let http_client = Arc::from(get_http_client(&notion_api_key));
Client {
http_client: http_client.clone(),
request_handler: request_handler.clone(),
pages: Pages {
http_client: http_client.clone(),
request_handler: request_handler.clone(),
},
blocks: Blocks {
http_client: http_client.clone(),
request_handler: request_handler.clone(),
},
databases: Databases {
http_client: http_client.clone(),
request_handler: request_handler.clone(),
},
users: Users {
http_client: http_client.clone(),
request_handler: request_handler.clone(),
},
http_client: get_http_client(&notion_api_key),
request_handler: self.custom_request.unwrap_or(Box::new(
|request_builder: surf::RequestBuilder| Box::pin(request_builder),
)),
}
}
}
pub struct Client {
http_client: Arc<reqwest::Client>,
request_handler: Arc<Callback>,
pub pages: Pages,
pub blocks: Blocks,
pub databases: Databases,
pub users: Users,
http_client: surf::Client,
request_handler: Box<Callback>,
}
impl<'a> Client {
#[allow(clippy::new_ret_no_self)]
pub fn new() -> ClientBuilder {
ClientBuilder::default()
}
pub async fn search<'b, T: std::fmt::Debug + for<'de> serde::Deserialize<'de>>(
self,
&self,
options: SearchOptions<'b>,
) -> Result<QueryResponse<T>> {
let mut request = self
let request = self
.http_client
.post("https://api.notion.com/v1/search")
.json(&options);
.body_json(&options)
.expect("to parse JSON for doing `search`");
let response = (self.request_handler)(&mut request).await?;
let mut response = (self.request_handler)(request)
.await
.expect("to request through a request handler");
match response.error_for_status_ref() {
Ok(_) => Ok(response.json().await?),
Err(error) => {
let body = response.json::<Value>().await?;
Err(Error::Http(error, Some(body)))
}
match response.status() {
StatusCode::Ok => Ok(response.body_json().await?),
status => Err(Error::Http(status, response)),
}
}
pub fn pages(&'a self) -> Pages<'a> {
Pages {
http_client: &self.http_client,
request_handler: &self.request_handler,
}
}
pub fn blocks(&'a self) -> Blocks<'a> {
Blocks {
http_client: &self.http_client,
request_handler: &self.request_handler,
}
}
pub fn databases(&'a self) -> Databases<'a> {
Databases {
http_client: &self.http_client,
request_handler: &self.request_handler,
}
}
pub fn users(&'a self) -> Users<'a> {
Users {
http_client: &self.http_client,
request_handler: &self.request_handler,
}
}
}
@ -225,88 +204,78 @@ pub struct PageOptions<'a> {
pub page_id: &'a str,
}
#[derive(Clone)]
pub struct Pages {
http_client: Arc<reqwest::Client>,
request_handler: Arc<Callback>,
pub struct Pages<'a> {
http_client: &'a surf::Client,
request_handler: &'a Box<Callback>,
}
impl Pages {
pub async fn retrieve<'a>(self, options: PageOptions<'a>) -> Result<Page> {
impl Pages<'_> {
pub async fn retrieve(&self, options: PageOptions<'_>) -> Result<Page> {
let url = format!(
"https://api.notion.com/v1/pages/{page_id}",
page_id = options.page_id
);
let mut request = self.http_client.get(url);
let request = self.http_client.get(url);
let mut response = (self.request_handler)(request).await?;
let response = (self.request_handler)(&mut request).await?;
match response.error_for_status_ref() {
Ok(_) => Ok(response.json().await?),
Err(error) => {
let body = response.json::<Value>().await?;
Err(Error::Http(error, Some(body)))
}
match response.status() {
StatusCode::Ok => Ok(response.body_json().await?),
status => Err(Error::Http(status, response)),
}
}
}
#[derive(Clone)]
pub struct Blocks {
http_client: Arc<reqwest::Client>,
request_handler: Arc<Callback>,
pub struct Blocks<'a> {
http_client: &'a surf::Client,
request_handler: &'a Box<Callback>,
}
impl Blocks {
impl Blocks<'_> {
pub fn children(&self) -> BlockChildren {
BlockChildren {
http_client: self.http_client.clone(),
request_handler: self.request_handler.clone(),
http_client: self.http_client,
request_handler: self.request_handler,
}
}
}
pub struct BlockChildren {
http_client: Arc<reqwest::Client>,
request_handler: Arc<Callback>,
pub struct BlockChildren<'a> {
http_client: &'a surf::Client,
request_handler: &'a Box<Callback>,
}
pub struct BlockChildrenListOptions<'a> {
pub block_id: &'a str,
}
impl BlockChildren {
pub async fn list<'a>(
self,
options: BlockChildrenListOptions<'a>,
impl BlockChildren<'_> {
pub async fn list(
&self,
options: BlockChildrenListOptions<'_>,
) -> Result<QueryResponse<Block>> {
let url = format!(
"https://api.notion.com/v1/blocks/{block_id}/children",
block_id = options.block_id
);
let mut request = self.http_client.get(&url);
let request = self.http_client.get(&url);
let response = (self.request_handler)(&mut request).await?;
let mut response = (self.request_handler)(request).await?;
match response.error_for_status_ref() {
Ok(_) => Ok(response.json().await?),
Err(error) => {
let body = response.json::<Value>().await?;
Err(Error::Http(error, Some(body)))
}
match response.status() {
StatusCode::Ok => Ok(response.body_json().await?),
status => Err(Error::Http(status, response)),
}
}
}
#[derive(Clone)]
pub struct Databases {
http_client: Arc<reqwest::Client>,
request_handler: Arc<Callback>,
pub struct Databases<'a> {
http_client: &'a surf::Client,
request_handler: &'a Box<Callback>,
}
impl Databases {
impl Databases<'_> {
pub async fn query<'a>(
&self,
options: DatabaseQueryOptions<'a>,
@ -318,11 +287,7 @@ impl Databases {
let mut request = self.http_client.post(url);
let json = if let Some(filter) = options.filter {
Some(json!({ "filter": filter }))
} else {
None
};
let json = options.filter.map(|filter| json!({ "filter": filter }));
let json = if let Some(sorts) = options.sorts {
if let Some(mut json) = json {
@ -352,18 +317,19 @@ impl Databases {
json
};
if let Some(json) = json {
request = request.json(&json);
if let Some(ref json) = json {
request = request
.body_json(json)
.expect("to parse JSON for start_cursor");
}
let response = (self.request_handler)(&mut request).await?;
log::trace!("Querying database with request: {request:?} and body: {json:?}");
match response.error_for_status_ref() {
Ok(_) => try_to_parse_response(response).await,
Err(error) => {
let body = try_to_parse_response::<Value>(response).await?;
Err(Error::Http(error, Some(body)))
}
let mut response = (self.request_handler)(request).await?;
match response.status() {
StatusCode::Ok => Ok(response.body_json().await?),
status => Err(Error::Http(status, response)),
}
}
}
@ -372,9 +338,9 @@ impl Databases {
mod tests {
use super::*;
#[tokio::test]
#[async_std::test]
async fn check_database_query() {
let databases = Client::new()
let _ = Client::new()
.api_key("secret_FuhJkAoOVZlk8YUT9ZOeYqWBRRZN6OMISJwhb4dTnud")
.build()
.search::<Database>(SearchOptions {
@ -390,23 +356,19 @@ mod tests {
start_cursor: None,
})
.await;
println!("{databases:#?}");
}
#[tokio::test]
#[async_std::test]
async fn test_blocks() {
let blocks = Client::new()
let _ = Client::new()
.api_key("secret_FuhJkAoOVZlk8YUT9ZOeYqWBRRZN6OMISJwhb4dTnud")
.build()
.blocks
.blocks()
.children()
.list(BlockChildrenListOptions {
block_id: "0d253ab0f751443aafb9bcec14012897",
})
.await;
println!("{blocks:#?}")
}
}
@ -419,26 +381,22 @@ pub struct DatabaseQueryOptions<'a> {
pub start_cursor: Option<String>,
}
#[derive(Clone)]
pub struct Users {
http_client: Arc<reqwest::Client>,
request_handler: Arc<Callback>,
pub struct Users<'a> {
http_client: &'a surf::Client,
request_handler: &'a Box<Callback>,
}
impl Users {
impl Users<'_> {
pub async fn get(&self) -> Result<QueryResponse<User>> {
let url = "https://api.notion.com/v1/users".to_owned();
let mut request = self.http_client.get(&url);
let request = self.http_client.get(&url);
let response = (self.request_handler)(&mut request).await?;
let mut response = (self.request_handler)(request).await?;
match response.error_for_status_ref() {
Ok(_) => Ok(response.json().await?),
Err(error) => {
let body = response.json::<Value>().await?;
Err(Error::Http(error, Some(body)))
}
match response.status() {
StatusCode::Ok => Ok(response.body_json().await?),
status => Err(Error::Http(status, response)),
}
}
}
@ -727,6 +685,9 @@ pub enum CodeLanguage {
YAML,
#[serde(rename = "java/c/c++/c#")]
JavaCCppCSharp,
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@ -956,8 +917,8 @@ where
serde_json::from_value::<DatabaseProperty>(value.to_owned()).unwrap_or_else(|error| {
log::warn!(
"Could not parse value because of error, defaulting to DatabaseProperty::Unsupported:\n= ERROR:\n{error:#?}\n= JSON:\n{:#?}\n---",
serde_json::to_string_pretty(&value).unwrap()
"Could not parse value because of error, defaulting to DatabaseProperty::Unsupported:\n= ERROR:\n{error:?}\n= JSON:\n{:?}\n---",
serde_json::to_string_pretty(&value).expect("to pretty print the database property error")
);
DatabaseProperty::Unsupported(value.to_owned())
}),
@ -983,7 +944,7 @@ pub struct Relation {
// TODO: Paginate all possible responses
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct QueryResponse<T> {
pub has_more: bool,
pub has_more: Option<bool>,
pub next_cursor: Option<String>,
pub results: Vec<T>,
}
@ -1180,7 +1141,7 @@ where
// Notion forgets to set the formula type, so we're doing it's homework
"formula" => {
if let Value::Object(object) = value {
if let None = object.get("type") {
if object.get("type").is_none() {
object.insert("type".to_owned(), json!("string"));
}
}
@ -1223,7 +1184,7 @@ where
serde_json::from_value::<Property>(value.to_owned()).unwrap_or_else(|error| {
log::warn!(
"Could not parse value because of error, defaulting to Property::Unsupported:\n= ERROR:\n{error:#?}\n= JSON:\n{}\n---",
serde_json::to_string_pretty(&value).unwrap()
serde_json::to_string_pretty(&value).expect("to pretty print Property errors")
);
Property::Unsupported(value.to_owned())
}),
@ -1236,10 +1197,21 @@ where
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum Formula {
Boolean { boolean: Option<bool> },
Date { date: Option<Date> },
Number { number: Option<f32> },
String { string: Option<String> },
Boolean {
boolean: Option<bool>,
},
Date {
date: Option<Date>,
},
Number {
number: Option<f32>,
},
String {
string: Option<String>,
},
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@ -1293,6 +1265,9 @@ pub enum RichText {
href: Option<String>,
annotations: Annotations,
},
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@ -1305,11 +1280,24 @@ pub struct Text {
#[serde(tag = "type")]
#[serde(rename_all = "lowercase")]
pub enum Mention {
Database { database: PartialDatabase },
Date { date: Date },
LinkPreview { link_preview: LinkPreview },
Page { page: PartialPage },
User { user: PartialUser },
Database {
database: PartialDatabase,
},
Date {
date: Date,
},
LinkPreview {
link_preview: LinkPreview,
},
Page {
page: PartialPage,
},
User {
user: PartialUser,
},
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@ -1435,33 +1423,60 @@ pub enum Color {
PurpleBackground,
PinkBackground,
RedBackground,
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum Parent {
PageId { page_id: String },
DatabaseId { database_id: String },
BlockId { block_id: String },
PageId {
page_id: String,
},
DatabaseId {
database_id: String,
},
BlockId {
block_id: String,
},
Workspace,
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "lowercase")]
pub enum File {
File { file: NotionFile },
External { external: ExternalFile },
File {
file: NotionFile,
},
External {
external: ExternalFile,
},
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "lowercase")]
pub enum Icon {
Emoji { emoji: String },
File { file: NotionFile },
External { external: ExternalFile },
Emoji {
emoji: String,
},
File {
file: NotionFile,
},
External {
external: ExternalFile,
},
#[serde(other)]
Unsupported,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@ -1482,4 +1497,7 @@ pub enum DatabaseFormulaType {
Date,
Number,
String,
#[serde(other)]
Unsupported,
}