Skip to content

Commit

Permalink
feat: update services api client (#1695)
Browse files Browse the repository at this point in the history
* feat: use reqwest client

* wip: add project and org ops to permit client

* chore: isolate reqwest changes

* refactor bytes response (unused)

* nits

* fix: json feature
  • Loading branch information
jonaro00 authored Mar 20, 2024
1 parent 4ab5f08 commit 27c5c37
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 144 deletions.
3 changes: 2 additions & 1 deletion common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ backend = [
"axum/matched-path",
"axum/json",
"claims",
"hyper/client",
"hyper",
"opentelemetry_sdk",
"opentelemetry-appender-tracing",
"opentelemetry-otlp",
"models",
"reqwest/json",
"rustrict", # only ProjectName model uses it
"thiserror",
"tokio",
Expand Down
85 changes: 24 additions & 61 deletions common/src/backends/client/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,9 @@
use headers::Authorization;
use http::{Method, Uri};
use http::Method;
use tracing::instrument;

use crate::models;

use super::{Error, ServicesApiClient};

/// Wrapper struct to make API calls to gateway easier
#[derive(Clone)]
pub struct Client {
public_client: ServicesApiClient,
private_client: ServicesApiClient,
}

impl Client {
/// Make a gateway client that is able to call the public and private APIs on gateway
pub fn new(public_uri: Uri, private_uri: Uri) -> Self {
Self {
public_client: ServicesApiClient::new(public_uri),
private_client: ServicesApiClient::new(private_uri),
}
}

/// Get the client of public API calls
pub fn public_client(&self) -> &ServicesApiClient {
&self.public_client
}

/// Get the client of private API calls
pub fn private_client(&self) -> &ServicesApiClient {
&self.private_client
}
}
use super::{header_map_with_bearer, Error, ServicesApiClient};

/// Interact with all the data relating to projects
#[allow(async_fn_in_trait)]
Expand Down Expand Up @@ -65,33 +37,29 @@ pub trait ProjectsDal {
}
}

impl ProjectsDal for Client {
impl ProjectsDal for ServicesApiClient {
#[instrument(skip_all)]
async fn get_user_project(
&self,
user_token: &str,
project_name: &str,
) -> Result<models::project::Response, Error> {
self.public_client
.request(
Method::GET,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await
self.get(
format!("projects/{}", project_name).as_str(),
Some(header_map_with_bearer(user_token)),
)
.await
}

#[instrument(skip_all)]
async fn head_user_project(&self, user_token: &str, project_name: &str) -> Result<bool, Error> {
self.public_client
.request_raw(
Method::HEAD,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await?;
self.request_raw(
Method::HEAD,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(header_map_with_bearer(user_token)),
)
.await?;

Ok(true)
}
Expand All @@ -101,13 +69,7 @@ impl ProjectsDal for Client {
&self,
user_token: &str,
) -> Result<Vec<models::project::Response>, Error> {
self.public_client
.request(
Method::GET,
"projects",
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
self.get("projects", Some(header_map_with_bearer(user_token)))
.await
}
}
Expand All @@ -116,24 +78,25 @@ impl ProjectsDal for Client {
mod tests {
use test_context::{test_context, AsyncTestContext};

use crate::backends::client::ServicesApiClient;
use crate::models::project::{Response, State};
use crate::test_utils::get_mocked_gateway_server;

use super::{Client, ProjectsDal};
use super::ProjectsDal;

impl AsyncTestContext for Client {
impl AsyncTestContext for ServicesApiClient {
async fn setup() -> Self {
let server = get_mocked_gateway_server().await;

Client::new(server.uri().parse().unwrap(), server.uri().parse().unwrap())
ServicesApiClient::new(server.uri().parse().unwrap())
}

async fn teardown(self) {}
}

#[test_context(Client)]
#[test_context(ServicesApiClient)]
#[tokio::test]
async fn get_user_projects(client: &mut Client) {
async fn get_user_projects(client: &mut ServicesApiClient) {
let res = client.get_user_projects("user-1").await.unwrap();

assert_eq!(
Expand All @@ -155,9 +118,9 @@ mod tests {
)
}

#[test_context(Client)]
#[test_context(ServicesApiClient)]
#[tokio::test]
async fn get_user_project_ids(client: &mut Client) {
async fn get_user_project_ids(client: &mut ServicesApiClient) {
let res = client.get_user_project_ids("user-2").await.unwrap();

assert_eq!(res, vec!["00000000000000000000000003"])
Expand Down
141 changes: 93 additions & 48 deletions common/src/backends/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::time::Duration;

use bytes::Bytes;
use headers::{ContentType, Header, HeaderMapExt};
use http::{Method, Request, StatusCode, Uri};
use hyper::{body, client::HttpConnector, Body, Client};
use headers::{Authorization, HeaderMapExt};
use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use reqwest::{Client, ClientBuilder, Response};
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use tracing::{trace, Span};
Expand All @@ -17,95 +19,143 @@ pub use resource_recorder::ResourceDal;

#[derive(Error, Debug)]
pub enum Error {
#[error("Hyper error: {0}")]
Hyper(#[from] hyper::Error),
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("Serde JSON error: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("Hyper error: {0}")]
Http(#[from] hyper::http::Error),
#[error("Request did not return correctly. Got status code: {0}")]
RequestError(StatusCode),
#[error("GRpc request did not return correctly. Got status code: {0}")]
GrpcError(#[from] tonic::Status),
}

/// `Hyper` wrapper to make request to RESTful Shuttle services easy
/// `reqwest` wrapper to make requests to other services easy
#[derive(Clone)]
pub struct ServicesApiClient {
client: Client<HttpConnector>,
client: Client,
base: Uri,
}

impl ServicesApiClient {
fn new(base: Uri) -> Self {
pub fn builder() -> ClientBuilder {
Client::builder().timeout(Duration::from_secs(60))
}

pub fn new(base: Uri) -> Self {
Self {
client: Self::builder().build().unwrap(),
base,
}
}

pub fn new_with_bearer(base: Uri, token: &str) -> Self {
Self {
client: Client::new(),
client: Self::builder()
.default_headers(header_map_with_bearer(token))
.build()
.unwrap(),
base,
}
}

pub async fn request<B: Serialize, T: DeserializeOwned, H: Header>(
pub async fn get<T: DeserializeOwned>(
&self,
path: &str,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
self.request(Method::GET, path, None::<()>, headers).await
}

pub async fn post<B: Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: B,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
self.request(Method::POST, path, Some(body), headers).await
}

pub async fn delete<B: Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: B,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
self.request(Method::DELETE, path, Some(body), headers)
.await
}

pub async fn request<B: Serialize, T: DeserializeOwned>(
&self,
method: Method,
path: &str,
body: Option<B>,
extra_header: Option<H>,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<T, Error> {
let bytes = self.request_raw(method, path, body, extra_header).await?;
let json = serde_json::from_slice(&bytes)?;

Ok(json)
Ok(self
.request_raw(method, path, body, headers)
.await?
.json()
.await?)
}

pub async fn request_raw<B: Serialize, H: Header>(
pub async fn request_bytes<B: Serialize>(
&self,
method: Method,
path: &str,
body: Option<B>,
extra_header: Option<H>,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<Bytes, Error> {
Ok(self
.request_raw(method, path, body, headers)
.await?
.bytes()
.await?)
}

// can be used for explicit HEAD requests (ignores body)
pub async fn request_raw<B: Serialize>(
&self,
method: Method,
path: &str,
body: Option<B>,
headers: Option<HeaderMap<HeaderValue>>,
) -> Result<Response, Error> {
let uri = format!("{}{path}", self.base);
trace!(uri, "calling inner service");

let mut req = Request::builder().method(method).uri(uri);
let headers = req
.headers_mut()
.expect("new request to have mutable headers");
if let Some(extra_header) = extra_header {
headers.typed_insert(extra_header);
}
if body.is_some() {
headers.typed_insert(ContentType::json());
}

let mut h = headers.unwrap_or_default();
let cx = Span::current().context();
global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut().unwrap()))
propagator.inject_context(&cx, &mut HeaderInjector(&mut h))
});

let req = self.client.request(method, uri).headers(h);
let req = if let Some(body) = body {
req.body(Body::from(serde_json::to_vec(&body)?))
req.json(&body)
} else {
req.body(Body::empty())
req
};

let resp = self.client.request(req?).await?;
trace!(response = ?resp, "Load response");
let resp = req.send().await?;
trace!(response = ?resp, "service response");

if resp.status() != StatusCode::OK {
if !resp.status().is_success() {
return Err(Error::RequestError(resp.status()));
}

let bytes = body::to_bytes(resp.into_body()).await?;

Ok(bytes)
Ok(resp)
}
}

pub fn header_map_with_bearer(token: &str) -> HeaderMap {
let mut h = HeaderMap::new();
h.typed_insert(Authorization::bearer(token).expect("valid token"));
h
}

#[cfg(test)]
mod tests {
use headers::{authorization::Bearer, Authorization};
use http::{Method, StatusCode};
use http::StatusCode;

use crate::models;
use crate::test_utils::get_mocked_gateway_server;
Expand All @@ -120,12 +170,7 @@ mod tests {
let client = ServicesApiClient::new(server.uri().parse().unwrap());

let err = client
.request::<_, Vec<models::project::Response>, _>(
Method::GET,
"projects",
None::<()>,
None::<Authorization<Bearer>>,
)
.get::<Vec<models::project::Response>>("projects", None)
.await
.unwrap_err();

Expand Down
Loading

0 comments on commit 27c5c37

Please sign in to comment.