Skip to content

Commit

Permalink
move all the reqwest related code to a different file
Browse files Browse the repository at this point in the history
  • Loading branch information
irevoire committed Apr 14, 2024
1 parent b58c067 commit 8867b9e
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 114 deletions.
3 changes: 2 additions & 1 deletion examples/cli-app-with-reqwest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ serde_json = "1.0"
lazy_static = "1.4.0"
reqwest = "0.11.16"
async-trait = "0.1.51"
tokio = { version = "1.27.0", features = ["full"] }
tokio = { version = "1.27.0", features = ["full"] }
yaup = "0.2.0"
16 changes: 12 additions & 4 deletions examples/cli-app-with-reqwest/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use async_trait::async_trait;
use lazy_static::lazy_static;
use meilisearch_sdk::errors::Error;
use meilisearch_sdk::request::{
add_query_parameters, parse_response, qualified_version, HttpClient, Method,
};
use meilisearch_sdk::request::{parse_response, qualified_version, HttpClient, Method};
use meilisearch_sdk::{client::*, settings::Settings};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
Expand All @@ -19,7 +17,7 @@ lazy_static! {
#[derive(Debug, Clone, Serialize)]
pub struct ReqwestClient;

#[async_trait(?Send)]
#[async_trait]
impl HttpClient for ReqwestClient {
async fn request<Query, Body, Output>(
&self,
Expand Down Expand Up @@ -249,3 +247,13 @@ impl fmt::Display for ClothesDisplay {
)
}
}

fn add_query_parameters<Query: Serialize>(url: &str, query: &Query) -> Result<String, Error> {
let query = yaup::to_string(query)?;

if query.is_empty() {
Ok(url.to_string())
} else {
Ok(format!("{url}?{query}"))
}
}
1 change: 1 addition & 0 deletions examples/cli-app/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ futures = "0.3"
serde = { version="1.0", features = ["derive"] }
serde_json = "1.0"
lazy_static = "1.4.0"
yaup = "0.2.0"
4 changes: 2 additions & 2 deletions meilisearch-index-setting-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn get_index_config_implementation(

quote! {
#[::meilisearch_sdk::macro_helper::async_trait(?Send)]
impl ::meilisearch_sdk::documents::IndexConfig<::meilisearch_sdk::request::ReqwestClient> for #struct_ident {
impl ::meilisearch_sdk::documents::IndexConfig<::meilisearch_sdk::reqwest::ReqwestClient> for #struct_ident {
const INDEX_STR: &'static str = #index_name;

fn generate_settings() -> ::meilisearch_sdk::settings::Settings {
Expand All @@ -140,7 +140,7 @@ fn get_index_config_implementation(
#distinct_attr_token
}

async fn generate_index(client: &::meilisearch_sdk::client::Client<::meilisearch_sdk::request::ReqwestClient>) -> std::result::Result<::meilisearch_sdk::indexes::Index<::meilisearch_sdk::request::ReqwestClient>, ::meilisearch_sdk::tasks::Task> {
async fn generate_index(client: &::meilisearch_sdk::client::Client<::meilisearch_sdk::reqwest::ReqwestClient>) -> std::result::Result<::meilisearch_sdk::indexes::Index<::meilisearch_sdk::reqwest::ReqwestClient>, ::meilisearch_sdk::tasks::Task> {
return client.create_index(#index_name, #primary_key_token)
.await.unwrap()
.wait_for_completion(&client, ::std::option::Option::None, ::std::option::Option::None)
Expand Down
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl Client {
/// ```
pub fn new(host: impl Into<String>, api_key: Option<impl Into<String>>) -> Client {
let api_key = api_key.map(|key| key.into());
let http_client = ReqwestClient::new(api_key.as_deref());
let http_client = crate::reqwest::ReqwestClient::new(api_key.as_deref());

Client {
host: host.into(),
Expand Down Expand Up @@ -1142,7 +1142,7 @@ mod tests {

use meilisearch_test_macro::meilisearch_test;

use crate::{client::*, key::Action};
use crate::{client::*, key::Action, reqwest::qualified_version};

#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Document {
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,10 @@ mod tenant_tokens;
mod utils;

#[cfg(feature = "reqwest")]
pub type DefaultHttpClient = request::ReqwestClient;
pub mod reqwest;

#[cfg(feature = "reqwest")]
pub type DefaultHttpClient = reqwest::ReqwestClient;

#[cfg(not(feature = "reqwest"))]
pub type DefaultHttpClient = std::convert::Infallible;
Expand Down
106 changes: 2 additions & 104 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,9 @@ impl<Q, B> Method<Q, B> {
Method::Patch { body, query: _ } => Some(body),
}
}

#[cfg(feature = "reqwest")]
pub fn verb(&self) -> reqwest::Method {
match self {
Method::Get { .. } => reqwest::Method::GET,
Method::Delete { .. } => reqwest::Method::DELETE,
Method::Post { .. } => reqwest::Method::POST,
Method::Put { .. } => reqwest::Method::PUT,
Method::Patch { .. } => reqwest::Method::PATCH,
}
}
}

#[async_trait(?Send)]
#[async_trait]
pub trait HttpClient: Clone + Send + Sync {
async fn request<Query, Body, Output>(
&self,
Expand Down Expand Up @@ -113,91 +102,6 @@ pub trait HttpClient: Clone + Send + Sync {
) -> Result<Output, Error>;
}

#[cfg(feature = "reqwest")]
#[derive(Debug, Clone, Default)]
pub struct ReqwestClient {
client: reqwest::Client,
}

#[cfg(feature = "reqwest")]
impl ReqwestClient {
pub fn new(api_key: Option<&str>) -> Self {
use reqwest::{header, ClientBuilder};

let builder = ClientBuilder::new();
let mut headers = header::HeaderMap::new();
headers.insert(
header::USER_AGENT,
header::HeaderValue::from_str(&qualified_version()).unwrap(),
);

if let Some(api_key) = api_key {
headers.insert(
header::AUTHORIZATION,
header::HeaderValue::from_str(&format!("Bearer {api_key}")).unwrap(),
);
}

let builder = builder.default_headers(headers);
let client = builder.build().unwrap();

ReqwestClient { client }
}
}

#[cfg(feature = "reqwest")]
#[async_trait(?Send)]
impl HttpClient for ReqwestClient {
async fn stream_request<
'a,
Query: Serialize + Send + Sync,
Body: futures_io::AsyncRead + Send + Sync + 'static,
Output: DeserializeOwned + 'static,
>(
&self,
url: &str,
method: Method<Query, Body>,
content_type: &str,
expected_status_code: u16,
) -> Result<Output, Error> {
use reqwest::header;

let url = add_query_parameters(url, method.query())?;

let mut request = self.client.request(method.verb(), &url);

if let Some(body) = method.into_body() {
let reader = tokio_util::compat::FuturesAsyncReadCompatExt::compat(body);
let stream = tokio_util::io::ReaderStream::new(reader);
let body = reqwest::Body::wrap_stream(stream);

request = request
.header(header::CONTENT_TYPE, content_type)
.body(body);
}

let response = self.client.execute(request.build()?).await?;
let status = response.status().as_u16();
let mut body = response.text().await?;

if body.is_empty() {
body = "null".to_string();
}

parse_response(status, expected_status_code, &body, url.to_string())
}
}

pub fn add_query_parameters<Query: Serialize>(url: &str, query: &Query) -> Result<String, Error> {
let query = yaup::to_string(query)?;

if query.is_empty() {
Ok(url.to_string())
} else {
Ok(format!("{url}?{query}"))
}
}

pub fn parse_response<Output: DeserializeOwned>(
status_code: u16,
expected_status_code: u16,
Expand Down Expand Up @@ -239,13 +143,7 @@ pub fn parse_response<Output: DeserializeOwned>(
}
}

pub fn qualified_version() -> String {
const VERSION: Option<&str> = option_env!("CARGO_PKG_VERSION");

format!("Meilisearch Rust (v{})", VERSION.unwrap_or("unknown"))
}

#[async_trait(?Send)]
#[async_trait]
impl HttpClient for Infallible {
async fn request<Query, Body, Output>(
&self,
Expand Down
106 changes: 106 additions & 0 deletions src/reqwest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};

use crate::{
errors::Error,
request::{parse_response, HttpClient, Method},
};

#[derive(Debug, Clone, Default)]
pub struct ReqwestClient {
client: reqwest::Client,
}

#[cfg(feature = "reqwest")]
impl ReqwestClient {
pub fn new(api_key: Option<&str>) -> Self {
use reqwest::{header, ClientBuilder};

let builder = ClientBuilder::new();
let mut headers = header::HeaderMap::new();
headers.insert(
header::USER_AGENT,
header::HeaderValue::from_str(&qualified_version()).unwrap(),
);

if let Some(api_key) = api_key {
headers.insert(
header::AUTHORIZATION,
header::HeaderValue::from_str(&format!("Bearer {api_key}")).unwrap(),
);
}

let builder = builder.default_headers(headers);
let client = builder.build().unwrap();

ReqwestClient { client }
}
}

#[async_trait]
impl HttpClient for ReqwestClient {
async fn stream_request<
'a,
Query: Serialize + Send + Sync,
Body: futures_io::AsyncRead + Send + Sync + 'static,
Output: DeserializeOwned + 'static,
>(
&self,
url: &str,
method: Method<Query, Body>,
content_type: &str,
expected_status_code: u16,
) -> Result<Output, Error> {
use reqwest::header;

let url = add_query_parameters(url, method.query())?;

let mut request = self.client.request(verb(&method), &url);

if let Some(body) = method.into_body() {
let reader = tokio_util::compat::FuturesAsyncReadCompatExt::compat(body);
let stream = tokio_util::io::ReaderStream::new(reader);
let body = reqwest::Body::wrap_stream(stream);

request = request
.header(header::CONTENT_TYPE, content_type)
.body(body);
}

let response = self.client.execute(request.build()?).await?;
let status = response.status().as_u16();
let mut body = response.text().await?;

if body.is_empty() {
body = "null".to_string();
}

parse_response(status, expected_status_code, &body, url.to_string())
}
}

fn verb<Q, B>(method: &Method<Q, B>) -> reqwest::Method {
match method {
Method::Get { .. } => reqwest::Method::GET,
Method::Delete { .. } => reqwest::Method::DELETE,
Method::Post { .. } => reqwest::Method::POST,
Method::Put { .. } => reqwest::Method::PUT,
Method::Patch { .. } => reqwest::Method::PATCH,
}
}

pub fn add_query_parameters<Query: Serialize>(url: &str, query: &Query) -> Result<String, Error> {
let query = yaup::to_string(query)?;

if query.is_empty() {
Ok(url.to_string())
} else {
Ok(format!("{url}?{query}"))
}
}

pub fn qualified_version() -> String {
const VERSION: Option<&str> = option_env!("CARGO_PKG_VERSION");

format!("Meilisearch Rust (v{})", VERSION.unwrap_or("unknown"))
}

0 comments on commit 8867b9e

Please sign in to comment.