From 07ce7cb1852dee608d10dc14c68ceaabd13c4e22 Mon Sep 17 00:00:00 2001 From: Mason Gup Date: Thu, 2 Jan 2025 16:36:19 -0500 Subject: [PATCH 1/6] Create additional middleware for optional signature --- Cargo.toml | 4 +- README.md | 8 ++- src/axum_service.rs | 120 ++++++++++++++++++++++++++++++++++----- src/validate_incoming.rs | 66 ++++++++++++++++++--- 4 files changed, 172 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ed038d3..8b68dbe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mauth-client" -version = "0.5.0" +version = "0.6.0" authors = ["Mason Gup "] edition = "2021" documentation = "https://docs.rs/mauth-client/" @@ -31,7 +31,7 @@ futures-core = { version = "0.3", optional = true } http = "1" bytes = { version = "1", optional = true } thiserror = "1" -mauth-core = "0.5" +mauth-core = "0.6" [dev-dependencies] tokio = { version = "1", features = ["rt-multi-thread", "macros"] } diff --git a/README.md b/README.md index a800925..eda2942 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,13 @@ match client.get("https://www.example.com/").send().await { The optional `axum-service` feature provides for a Tower Layer and Service that will authenticate incoming requests via MAuth V2 or V1 and provide to the lower layers a -validated app_uuid from the request via the ValidatedRequestDetails struct. +validated app_uuid from the request via the ValidatedRequestDetails struct. Note that +this feature now includes a `RequiredMAuthValidationLayer`, which will reject any +requests without a valid signature before they reach lower layers, and also a +`OptionalMAuthValidationLayer`, which lets all requests through, but only attaches a +ValidatedRequestDetails extension struct if there is a valid signature. When using this +layer, it is the responsiblity of the request handler to check for the extension and +reject requests that are not properly authorized. There are also optional features `tracing-otel-26` and `tracing-otel-27` that pair with the `axum-service` feature to ensure that any outgoing requests for credentials that take diff --git a/src/axum_service.rs b/src/axum_service.rs index 04cf6ae..51faf2b 100644 --- a/src/axum_service.rs +++ b/src/axum_service.rs @@ -14,13 +14,13 @@ use crate::{ /// This is a Tower Service which validates that incoming requests have a valid /// MAuth signature. It only passes the request down to the next layer if the /// signature is valid, otherwise it returns an appropriate error to the caller. -pub struct MAuthValidationService { +pub struct RequiredMAuthValidationService { mauth_info: MAuthInfo, config_info: ConfigFileSection, service: S, } -impl Service for MAuthValidationService +impl Service for RequiredMAuthValidationService where S: Service + Send + Clone + 'static, S::Future: Send + 'static, @@ -48,9 +48,9 @@ where } } -impl Clone for MAuthValidationService { +impl Clone for RequiredMAuthValidationService { fn clone(&self) -> Self { - MAuthValidationService { + RequiredMAuthValidationService { // unwrap is safe because we validated the config_info before constructing the layer mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), config_info: self.config_info.clone(), @@ -59,18 +59,18 @@ impl Clone for MAuthValidationService { } } -/// This is a Tower Layer which applies the MAuthValidationService on top of the +/// This is a Tower Layer which applies the RequiredMAuthValidationService on top of the /// service provided to it. #[derive(Clone)] -pub struct MAuthValidationLayer { +pub struct RequiredMAuthValidationLayer { config_info: ConfigFileSection, } -impl Layer for MAuthValidationLayer { - type Service = MAuthValidationService; +impl Layer for RequiredMAuthValidationLayer { + type Service = RequiredMAuthValidationService; fn layer(&self, service: S) -> Self::Service { - MAuthValidationService { + RequiredMAuthValidationService { // unwrap is safe because we validated the config_info before constructing the layer mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), config_info: self.config_info.clone(), @@ -79,21 +79,113 @@ impl Layer for MAuthValidationLayer { } } -impl MAuthValidationLayer { - /// Construct a MAuthValidationLayer based on the configuration options in the file +impl RequiredMAuthValidationLayer { + /// Construct a RequiredMAuthValidationLayer based on the configuration options in the file /// found in the default location. pub fn from_default_file() -> Result { let config_info = MAuthInfo::config_section_from_default_file()?; // Generate a MAuthInfo and then drop it to validate that it works, // making it safe to use `unwrap` in the service constructor. MAuthInfo::from_config_section(&config_info)?; - Ok(MAuthValidationLayer { config_info }) + Ok(RequiredMAuthValidationLayer { config_info }) } - /// Construct a MAuthValidationLayer based on the configuration options in a manually + /// Construct a RequiredMAuthValidationLayer based on the configuration options in a manually /// created or parsed ConfigFileSection. pub fn from_config_section(config_info: ConfigFileSection) -> Result { MAuthInfo::from_config_section(&config_info)?; - Ok(MAuthValidationLayer { config_info }) + Ok(RequiredMAuthValidationLayer { config_info }) + } +} + +/// This is a Tower Service which validates that incoming requests have a valid +/// MAuth signature. Unlike the Required service, if this service is not able to +/// find or validate a signature, it passes the request down to the lower layers +/// anyways. This means that it is the responsibility of the request handler to +/// check for the `ValidatedRequestDetails` extension to determine if the request +/// has a valid signature. It also means that this service is safe to attach to +/// the whole application, even if some requests are not validated at all or may +/// be validated in a different way. +pub struct OptionalMAuthValidationService { + mauth_info: MAuthInfo, + config_info: ConfigFileSection, + service: S, +} + +impl Service for OptionalMAuthValidationService +where + S: Service + Send + Clone + 'static, + S::Future: Send + 'static, + S::Error: Into>, +{ + type Response = S::Response; + type Error = Box; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx).map_err(|e| e.into()) + } + + fn call(&mut self, request: Request) -> Self::Future { + let mut cloned = self.clone(); + Box::pin(async move { + match cloned.mauth_info.validate_request_optionally(request).await { + Ok(valid_request) => match cloned.service.call(valid_request).await { + Ok(response) => Ok(response), + Err(err) => Err(err.into()), + }, + Err(err) => Err(Box::new(err) as Box), + } + }) + } +} + +impl Clone for OptionalMAuthValidationService { + fn clone(&self) -> Self { + OptionalMAuthValidationService { + // unwrap is safe because we validated the config_info before constructing the layer + mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), + config_info: self.config_info.clone(), + service: self.service.clone(), + } + } +} + +/// This is a Tower Layer which applies the OptionalMAuthValidationService on top of the +/// service provided to it. +#[derive(Clone)] +pub struct OptionalMAuthValidationLayer { + config_info: ConfigFileSection, +} + +impl Layer for OptionalMAuthValidationLayer { + type Service = OptionalMAuthValidationService; + + fn layer(&self, service: S) -> Self::Service { + OptionalMAuthValidationService { + // unwrap is safe because we validated the config_info before constructing the layer + mauth_info: MAuthInfo::from_config_section(&self.config_info).unwrap(), + config_info: self.config_info.clone(), + service, + } + } +} + +impl OptionalMAuthValidationLayer { + /// Construct an OptionalMAuthValidationLayer based on the configuration options in the file + /// found in the default location. + pub fn from_default_file() -> Result { + let config_info = MAuthInfo::config_section_from_default_file()?; + // Generate a MAuthInfo and then drop it to validate that it works, + // making it safe to use `unwrap` in the service constructor. + MAuthInfo::from_config_section(&config_info)?; + Ok(OptionalMAuthValidationLayer { config_info }) + } + + /// Construct an OptionalMAuthValidationLayer based on the configuration options in a manually + /// created or parsed ConfigFileSection. + pub fn from_config_section(config_info: ConfigFileSection) -> Result { + MAuthInfo::from_config_section(&config_info)?; + Ok(OptionalMAuthValidationLayer { config_info }) } } diff --git a/src/validate_incoming.rs b/src/validate_incoming.rs index ac02f5d..9539a63 100644 --- a/src/validate_incoming.rs +++ b/src/validate_incoming.rs @@ -1,4 +1,5 @@ use crate::{MAuthInfo, CLIENT, PUBKEY_CACHE}; +use axum::extract::Request; use chrono::prelude::*; use mauth_core::verifier::Verifier; use thiserror::Error; @@ -15,11 +16,16 @@ pub struct ValidatedRequestDetails { pub app_uuid: Uuid, } +const MAUTH_V1_SIGNATURE_HEADER: &str = "X-MWS-Authentication"; +const MAUTH_V2_SIGNATURE_HEADER: &str = "MCC-Authentication"; +const MAUTH_V1_TIMESTAMP_HEADER: &str = "X-MWS-Time"; +const MAUTH_V2_TIMESTAMP_HEADER: &str = "MCC-Time"; + impl MAuthInfo { pub(crate) async fn validate_request( &self, - req: axum::extract::Request, - ) -> Result { + req: Request, + ) -> Result { let (mut parts, body) = req.into_parts(); let body_bytes = axum::body::to_bytes(body, usize::MAX) .await @@ -30,7 +36,7 @@ impl MAuthInfo { app_uuid: host_app_uuid, }); let new_body = axum::body::Body::from(body_bytes); - let new_request = axum::extract::Request::from_parts(parts, new_body); + let new_request = Request::from_parts(parts, new_body); Ok(new_request) } Err(err) => { @@ -41,7 +47,7 @@ impl MAuthInfo { app_uuid: host_app_uuid, }); let new_body = axum::body::Body::from(body_bytes); - let new_request = axum::extract::Request::from_parts(parts, new_body); + let new_request = Request::from_parts(parts, new_body); Ok(new_request) } Err(err) => Err(err), @@ -53,6 +59,48 @@ impl MAuthInfo { } } + pub(crate) async fn validate_request_optionally( + &self, + req: Request, + ) -> Result { + let (mut parts, body) = req.into_parts(); + if parts.headers.contains_key(MAUTH_V2_SIGNATURE_HEADER) + || parts.headers.contains_key(MAUTH_V1_SIGNATURE_HEADER) + { + let body_bytes = axum::body::to_bytes(body, usize::MAX).await?; + + match self.validate_request_v2(&parts, &body_bytes).await { + Ok(host_app_uuid) => { + parts.extensions.insert(ValidatedRequestDetails { + app_uuid: host_app_uuid, + }); + } + Err(err) => { + if self.allow_v1_auth { + match self.validate_request_v1(&parts, &body_bytes).await { + Ok(host_app_uuid) => { + parts.extensions.insert(ValidatedRequestDetails { + app_uuid: host_app_uuid, + }); + } + Err(err) => { + parts.extensions.insert(err); + } + } + } else { + parts.extensions.insert(err); + } + } + } + + let new_body = axum::body::Body::from(body_bytes); + let new_request = Request::from_parts(parts, new_body); + Ok(new_request) + } else { + Ok(Request::from_parts(parts, body)) + } + } + async fn validate_request_v2( &self, req: &http::request::Parts, @@ -61,7 +109,7 @@ impl MAuthInfo { //retrieve and parse auth string let sig_header = req .headers - .get("MCC-Authentication") + .get(MAUTH_V2_SIGNATURE_HEADER) .ok_or(MAuthValidationError::NoSig)? .to_str() .map_err(|_| MAuthValidationError::InvalidSignature)?; @@ -70,7 +118,7 @@ impl MAuthInfo { //retrieve and validate timestamp let ts_str = req .headers - .get("MCC-Time") + .get(MAUTH_V2_TIMESTAMP_HEADER) .ok_or(MAuthValidationError::NoTime)? .to_str() .map_err(|_| MAuthValidationError::InvalidTime)?; @@ -107,7 +155,7 @@ impl MAuthInfo { //retrieve and parse auth string let sig_header = req .headers - .get("X-MWS-Authentication") + .get(MAUTH_V1_SIGNATURE_HEADER) .ok_or(MAuthValidationError::NoSig)? .to_str() .map_err(|_| MAuthValidationError::InvalidSignature)?; @@ -116,7 +164,7 @@ impl MAuthInfo { //retrieve and validate timestamp let ts_str = req .headers - .get("X-MWS-Time") + .get(MAUTH_V1_TIMESTAMP_HEADER) .ok_or(MAuthValidationError::NoTime)? .to_str() .map_err(|_| MAuthValidationError::InvalidTime)?; @@ -218,7 +266,7 @@ impl MAuthInfo { } /// All of the possible errors that can take place when attempting to verify a response signature -#[derive(Debug, Error)] +#[derive(Debug, Error, Clone)] pub enum MAuthValidationError { /// The timestamp of the response was either invalid or outside of the permitted /// range From 3eb3ad4c91ea5c00afa239a49b6b95f3b8d9cc92 Mon Sep 17 00:00:00 2001 From: Mason Gup Date: Thu, 2 Jan 2025 17:19:08 -0500 Subject: [PATCH 2/6] Impl FromRequestParts for ValidatedRequestDetails --- README.md | 11 +++++++++-- src/axum_service.rs | 17 ++++++++++++++++- src/reqwest_middleware.rs | 1 - 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index eda2942..8c02652 100644 --- a/README.md +++ b/README.md @@ -51,14 +51,21 @@ match client.get("https://www.example.com/").send().await { The optional `axum-service` feature provides for a Tower Layer and Service that will authenticate incoming requests via MAuth V2 or V1 and provide to the lower layers a -validated app_uuid from the request via the ValidatedRequestDetails struct. Note that +validated app_uuid from the request via the `ValidatedRequestDetails` struct. Note that this feature now includes a `RequiredMAuthValidationLayer`, which will reject any requests without a valid signature before they reach lower layers, and also a `OptionalMAuthValidationLayer`, which lets all requests through, but only attaches a -ValidatedRequestDetails extension struct if there is a valid signature. When using this +`ValidatedRequestDetails` extension struct if there is a valid signature. When using this layer, it is the responsiblity of the request handler to check for the extension and reject requests that are not properly authorized. +Note that `ValidatedRequestDetails` implements Axum's `FromRequestParts`, so you can +specify it bare in a request handler. This implementation includes returning a 401 +Unauthorized status code if the extension is not present. If you would like to return +a different response, or respond to the lack of the extension in another way, you can +use a more manual mechanism to check for the extension and decide how to proceed if it +is not present. + There are also optional features `tracing-otel-26` and `tracing-otel-27` that pair with the `axum-service` feature to ensure that any outgoing requests for credentials that take place in the context of an incoming web request also include the proper OpenTelemetry span diff --git a/src/axum_service.rs b/src/axum_service.rs index 51faf2b..9c66297 100644 --- a/src/axum_service.rs +++ b/src/axum_service.rs @@ -1,11 +1,13 @@ //! Structs and impls related to providing a Tower Service and Layer to verify incoming requests -use axum::extract::Request; +use axum::extract::{FromRequestParts, Request}; use futures_core::future::BoxFuture; +use http::{request::Parts, StatusCode}; use std::error::Error; use std::task::{Context, Poll}; use tower::{Layer, Service}; +use crate::validate_incoming::ValidatedRequestDetails; use crate::{ config::{ConfigFileSection, ConfigReadError}, MAuthInfo, @@ -189,3 +191,16 @@ impl OptionalMAuthValidationLayer { Ok(OptionalMAuthValidationLayer { config_info }) } } + +#[async_trait::async_trait] +impl FromRequestParts for ValidatedRequestDetails { + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(StatusCode::UNAUTHORIZED) + } +} diff --git a/src/reqwest_middleware.rs b/src/reqwest_middleware.rs index 71f6c0d..80c6cb8 100644 --- a/src/reqwest_middleware.rs +++ b/src/reqwest_middleware.rs @@ -6,7 +6,6 @@ use crate::{sign_outgoing::SigningError, MAuthInfo}; #[async_trait::async_trait] impl Middleware for MAuthInfo { - #[must_use] async fn handle( &self, mut req: Request, From 89a0735b988bffa795eae590f24dcdf89b9de80f Mon Sep 17 00:00:00 2001 From: Mason Gup Date: Fri, 3 Jan 2025 14:50:27 -0500 Subject: [PATCH 3/6] Refactoring how some things work and add doc tests for middleware updates --- Cargo.toml | 5 +- README.md | 102 +++++++++++++++++++++++++++++++++++++++ src/axum_service.rs | 62 ++++++++++++++++-------- src/validate_incoming.rs | 27 ++++++++--- 4 files changed, 168 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8b68dbe..10b7916 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,17 +26,18 @@ dirs = "5" chrono = "0.4" tokio = { version = "1", features = ["fs"] } tower = { version = "0.4", optional = true } -axum = { version = ">= 0.7.2", optional = true } +axum = { version = ">= 0.8", optional = true } futures-core = { version = "0.3", optional = true } http = "1" bytes = { version = "1", optional = true } thiserror = "1" mauth-core = "0.6" +tracing = { version = "0.1", optional = true } [dev-dependencies] tokio = { version = "1", features = ["rt-multi-thread", "macros"] } [features] -axum-service = ["tower", "futures-core", "axum", "bytes"] +axum-service = ["tower", "futures-core", "axum", "bytes", "tracing"] tracing-otel-26 = ["reqwest-tracing/opentelemetry_0_26"] tracing-otel-27 = ["reqwest-tracing/opentelemetry_0_27"] diff --git a/README.md b/README.md index 8c02652..e0c6161 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ the MAuth protocol, and verify the responses. Usage example: release any code to Production or deploy in a Client-accessible environment without getting approval for the full stack used through the Architecture and Security groups. +## Outgoing Requests + ```no_run use mauth_client::MAuthInfo; use reqwest::Client; @@ -49,6 +51,8 @@ match client.get("https://www.example.com/").send().await { # } ``` +## Incoming Requests + The optional `axum-service` feature provides for a Tower Layer and Service that will authenticate incoming requests via MAuth V2 or V1 and provide to the lower layers a validated app_uuid from the request via the `ValidatedRequestDetails` struct. Note that @@ -66,6 +70,104 @@ a different response, or respond to the lack of the extension in another way, yo use a more manual mechanism to check for the extension and decide how to proceed if it is not present. +### Examples for `RequiredMAuthValidationLayer` + +```no_run +# async fn run_server() { +use mauth_client::{ + axum_service::RequiredMAuthValidationLayer, + validate_incoming::ValidatedRequestDetails, +}; +use axum::{http::StatusCode, Router, routing::get, serve}; +use tokio::net::TcpListener; + +// If there is not a valid mauth signature, this function will never run at all, and +// the request will return an empty 401 Unauthorized +async fn foo() -> StatusCode { + StatusCode::OK +} + +// In addition to returning a 401 Unauthorized without running if there is not a valid +// MAuth signature, this also makes the validated requesting app UUID available to +// the function +async fn bar(details: ValidatedRequestDetails) -> StatusCode { + println!("Got a request from app with UUID: {}", details.app_uuid); + StatusCode::OK +} + +// This function will run regardless of whether or not there is a mauth signature +async fn baz() -> StatusCode { + StatusCode::OK +} + +// Attaching the baz route handler after the layer means the layer is not run for +// requests to that path, so no mauth checking will be performed for that route and +// any other routes attached after the layer +let router = Router::new() + .route("/foo", get(foo)) + .route("/bar", get(bar)) + .layer(RequiredMAuthValidationLayer::from_default_file().unwrap()) + .route("/baz", get(baz)); +let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); +serve(listener, router).await.unwrap(); +# } +``` + +### Examples for `OptionalMAuthValidationLayer` + +```no_run +# async fn run_server() { +use mauth_client::{ + axum_service::OptionalMAuthValidationLayer, + validate_incoming::ValidatedRequestDetails, +}; +use axum::{http::StatusCode, Router, routing::get, serve}; +use tokio::net::TcpListener; + +// This request will run no matter what the authorization status is +async fn foo() -> StatusCode { + StatusCode::OK +} + +// If there is not a valid mauth signature, this function will never run at all, and +// the request will return an empty 401 Unauthorized +async fn bar(_: ValidatedRequestDetails) -> StatusCode { + StatusCode::OK +} + +// In addition to returning a 401 Unauthorized without running if there is not a valid +// MAuth signature, this also makes the validated requesting app UUID available to +// the function +async fn baz(details: ValidatedRequestDetails) -> StatusCode { + println!("Got a request from app with UUID: {}", details.app_uuid); + StatusCode::OK +} + +// This request will run whether or not there is a valid mauth signature, but the Option +// provided can be used to tell you whether there was a valid signature, so you can +// implement things like multiple possible types of authentication or behavior other than +// a 401 return if there is no authentication +async fn bam(optional_details: Option) -> StatusCode { + match optional_details { + Some(details) => println!("Got a request from app with UUID: {}", details.app_uuid), + None => println!("Got a request without a valid mauth signature"), + } + StatusCode::OK +} + +let router = Router::new() + .route("/foo", get(foo)) + .route("/bar", get(bar)) + .route("/baz", get(baz)) + .route("/bam", get(bam)) + .layer(OptionalMAuthValidationLayer::from_default_file().unwrap()); +let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); +serve(listener, router).await.unwrap(); +# } +``` + +### OpenTelemetry Integration + There are also optional features `tracing-otel-26` and `tracing-otel-27` that pair with the `axum-service` feature to ensure that any outgoing requests for credentials that take place in the context of an incoming web request also include the proper OpenTelemetry span diff --git a/src/axum_service.rs b/src/axum_service.rs index 9c66297..980c938 100644 --- a/src/axum_service.rs +++ b/src/axum_service.rs @@ -1,11 +1,17 @@ //! Structs and impls related to providing a Tower Service and Layer to verify incoming requests -use axum::extract::{FromRequestParts, Request}; +use axum::{ + body::Body, + extract::{FromRequestParts, OptionalFromRequestParts, Request}, + response::IntoResponse, +}; use futures_core::future::BoxFuture; -use http::{request::Parts, StatusCode}; +use http::{request::Parts, Response, StatusCode}; +use std::convert::Infallible; use std::error::Error; use std::task::{Context, Poll}; use tower::{Layer, Service}; +use tracing::error; use crate::validate_incoming::ValidatedRequestDetails; use crate::{ @@ -27,13 +33,14 @@ where S: Service + Send + Clone + 'static, S::Future: Send + 'static, S::Error: Into>, + S::Response: Into>, { - type Response = S::Response; - type Error = Box; + type Response = Response; + type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx).map_err(|e| e.into()) + self.service.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { @@ -41,10 +48,16 @@ where Box::pin(async move { match cloned.mauth_info.validate_request(request).await { Ok(valid_request) => match cloned.service.call(valid_request).await { - Ok(response) => Ok(response), - Err(err) => Err(err.into()), + Ok(response) => Ok(response.into()), + Err(err) => Err(err), }, - Err(err) => Err(Box::new(err) as Box), + Err(err) => { + error!( + error = ?err, + "Failed to validate MAuth signature, rejecting request" + ); + Ok(StatusCode::UNAUTHORIZED.into_response()) + } } }) } @@ -121,23 +134,18 @@ where S::Error: Into>, { type Response = S::Response; - type Error = Box; + type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx).map_err(|e| e.into()) + self.service.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { let mut cloned = self.clone(); Box::pin(async move { - match cloned.mauth_info.validate_request_optionally(request).await { - Ok(valid_request) => match cloned.service.call(valid_request).await { - Ok(response) => Ok(response), - Err(err) => Err(err.into()), - }, - Err(err) => Err(Box::new(err) as Box), - } + let processed_request = cloned.mauth_info.validate_request_optionally(request).await; + cloned.service.call(processed_request).await }) } } @@ -192,8 +200,10 @@ impl OptionalMAuthValidationLayer { } } -#[async_trait::async_trait] -impl FromRequestParts for ValidatedRequestDetails { +impl FromRequestParts for ValidatedRequestDetails +where + S: Send + Sync, +{ type Rejection = StatusCode; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { @@ -204,3 +214,17 @@ impl FromRequestParts for ValidatedRequestDetails { .ok_or(StatusCode::UNAUTHORIZED) } } + +impl OptionalFromRequestParts for ValidatedRequestDetails +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +} diff --git a/src/validate_incoming.rs b/src/validate_incoming.rs index 9539a63..8529cc4 100644 --- a/src/validate_incoming.rs +++ b/src/validate_incoming.rs @@ -1,8 +1,10 @@ use crate::{MAuthInfo, CLIENT, PUBKEY_CACHE}; use axum::extract::Request; +use bytes::Bytes; use chrono::prelude::*; use mauth_core::verifier::Verifier; use thiserror::Error; +use tracing::error; use uuid::Uuid; /// This struct holds the app UUID for a validated request. It is meant to be used with the @@ -59,15 +61,26 @@ impl MAuthInfo { } } - pub(crate) async fn validate_request_optionally( - &self, - req: Request, - ) -> Result { + pub(crate) async fn validate_request_optionally(&self, req: Request) -> Request { let (mut parts, body) = req.into_parts(); if parts.headers.contains_key(MAUTH_V2_SIGNATURE_HEADER) || parts.headers.contains_key(MAUTH_V1_SIGNATURE_HEADER) { - let body_bytes = axum::body::to_bytes(body, usize::MAX).await?; + // By my reading of the code for this it should never fail, since we are passing + // MAX for the limit. But just to be safe, we will log the error and proceed with + // an empty body just in case instead of unwrapping. This would cause the body to + // be unavailable to the lower layers, but they would probably also fail to get it + // anyways since we just did here. + let body_bytes = match axum::body::to_bytes(body, usize::MAX).await { + Ok(bytes) => bytes, + Err(err) => { + error!( + error = ?err, + "Failed to retrieve request body, continuing with empty body" + ); + Bytes::new() + } + }; match self.validate_request_v2(&parts, &body_bytes).await { Ok(host_app_uuid) => { @@ -95,9 +108,9 @@ impl MAuthInfo { let new_body = axum::body::Body::from(body_bytes); let new_request = Request::from_parts(parts, new_body); - Ok(new_request) + new_request } else { - Ok(Request::from_parts(parts, body)) + Request::from_parts(parts, body) } } From f48cfb982db67adc7c6ea92c10ce7af324f9f8f9 Mon Sep 17 00:00:00 2001 From: Mason Gup Date: Fri, 3 Jan 2025 14:53:54 -0500 Subject: [PATCH 4/6] Clippy --- src/validate_incoming.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/validate_incoming.rs b/src/validate_incoming.rs index 8529cc4..79e73f3 100644 --- a/src/validate_incoming.rs +++ b/src/validate_incoming.rs @@ -107,8 +107,7 @@ impl MAuthInfo { } let new_body = axum::body::Body::from(body_bytes); - let new_request = Request::from_parts(parts, new_body); - new_request + Request::from_parts(parts, new_body) } else { Request::from_parts(parts, body) } From 6465766eef69f08a766e708c5ae95c7dd71f7541 Mon Sep 17 00:00:00 2001 From: Mason Gup Date: Fri, 3 Jan 2025 17:01:03 -0500 Subject: [PATCH 5/6] Add error logging for signature validation failures in optional middleware --- src/validate_incoming.rs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/validate_incoming.rs b/src/validate_incoming.rs index 79e73f3..dbfc870 100644 --- a/src/validate_incoming.rs +++ b/src/validate_incoming.rs @@ -73,9 +73,9 @@ impl MAuthInfo { // anyways since we just did here. let body_bytes = match axum::body::to_bytes(body, usize::MAX).await { Ok(bytes) => bytes, - Err(err) => { + Err(error) => { error!( - error = ?err, + ?error, "Failed to retrieve request body, continuing with empty body" ); Bytes::new() @@ -88,7 +88,7 @@ impl MAuthInfo { app_uuid: host_app_uuid, }); } - Err(err) => { + Err(error_v2) => { if self.allow_v1_auth { match self.validate_request_v1(&parts, &body_bytes).await { Ok(host_app_uuid) => { @@ -96,12 +96,18 @@ impl MAuthInfo { app_uuid: host_app_uuid, }); } - Err(err) => { - parts.extensions.insert(err); + Err(error_v1) => { + error!( + ?error_v2, + ?error_v1, + "Error attempting to validate MAuth signatures" + ); + parts.extensions.insert(error_v1); } } } else { - parts.extensions.insert(err); + error!(?error_v2, "Error attempting to validate MAuth V2 signature"); + parts.extensions.insert(error_v2); } } } From d613f7133504dbfd4e35f59524d028aeede22381 Mon Sep 17 00:00:00 2001 From: Mason Gup Date: Fri, 3 Jan 2025 18:00:39 -0500 Subject: [PATCH 6/6] Document error handling --- README.md | 14 ++++++++++++++ src/axum_service.rs | 16 +++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e0c6161..ba29e9f 100644 --- a/README.md +++ b/README.md @@ -166,6 +166,20 @@ serve(listener, router).await.unwrap(); # } ``` +### Error Handling + +Both the `RequiredMAuthValidationLayer` and the `OptionalMAuthValidationLayer` layers will +log errors encountered via `tracing` under the `mauth_client::validate_incoming` target. + +The Required layer returns the 401 response immediately, so there is no convenient way to +retrieve the error in order to do anything more sophisticated with it. + +The Optional layer, in addition to loging the error, will also add the `MAuthValidationError` +to the request extensions. If desired, any request handlers or middlewares can retrieve it +from there in order to take further actions based on the error type. This error type also +implements Axum's `OptionalFromRequestParts`, so you can more easily retrieve it using +`Option` anywhere that supports extractors. + ### OpenTelemetry Integration There are also optional features `tracing-otel-26` and `tracing-otel-27` that pair with diff --git a/src/axum_service.rs b/src/axum_service.rs index 980c938..4cdf1a6 100644 --- a/src/axum_service.rs +++ b/src/axum_service.rs @@ -13,7 +13,7 @@ use std::task::{Context, Poll}; use tower::{Layer, Service}; use tracing::error; -use crate::validate_incoming::ValidatedRequestDetails; +use crate::validate_incoming::{MAuthValidationError, ValidatedRequestDetails}; use crate::{ config::{ConfigFileSection, ConfigReadError}, MAuthInfo, @@ -228,3 +228,17 @@ where Ok(parts.extensions.get::().cloned()) } } + +impl OptionalFromRequestParts for MAuthValidationError +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +}