diff --git a/CHANGELOG.md b/CHANGELOG.md index 13e1be7..d5dfd69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,14 @@ All notable changes to this project will be documented in this file. ## [Unreleased] -- +### Added + +- Support for response body size metric, which can be turned on via `PrometheusMetricLayerBuilder::enable_response_body_size`. +- All metrics now are initialized via `metrics::describe_*` function by default, but can be turned off with `PrometheusMetricLayerBuilder::no_initialize_metrics`. + +### Changed + +- The lower-level Lifecycle API has changed: separated the `OnBodyChunk` trait, which is ran when a response body chunk has been generated. # [0.4.0] - 2023-07-24 diff --git a/README.md b/README.md index 276f561..eeb418f 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ By default three HTTP metrics are tracked - `axum_http_requests_duration_seconds` (labels: endpoint, method, status): the request duration for all HTTP requests handled (histogram) - `axum_http_requests_pending` (labels: endpoint, method): the number of currently in-flight requests (gauge) -Note that in the future request size metric is also planned to be implemented. +This crate also allows to track response body sizes as a histogram — see `PrometheusMetricLayerBuilder::enable_response_body_size`. ### Renaming Metrics @@ -33,6 +33,7 @@ These metrics can be renamed by specifying environmental variables at compile ti - `AXUM_HTTP_REQUESTS_TOTAL` - `AXUM_HTTP_REQUESTS_DURATION_SECONDS` - `AXUM_HTTP_REQUESTS_PENDING` +- `AXUM_HTTP_RESPONSE_BODY_SIZE` (if body size tracking is enabled) These environmental variables can be set in your `.cargo/config.toml` since Cargo 1.56: @@ -41,6 +42,7 @@ These environmental variables can be set in your `.cargo/config.toml` since Carg AXUM_HTTP_REQUESTS_TOTAL = "my_app_requests_total" AXUM_HTTP_REQUESTS_DURATION_SECONDS = "my_app_requests_duration_seconds" AXUM_HTTP_REQUESTS_PENDING = "my_app_requests_pending" +AXUM_HTTP_RESPONSE_BODY_SIZE = "my_app_response_body_size" ``` ..or optionally use `PrometheusMetricLayerBuilder::with_prefix` function. @@ -133,6 +135,8 @@ struct Recorder; // In order to use this with `axum_prometheus`, we must implement `MakeDefaultHandle`. impl MakeDefaultHandle for Recorder { + // We don't need to return anything meaningful from here (unlike PrometheusHandle) + // Let's just return an empty tuple. type Out = (); fn make_default_handle() -> Self::Out { @@ -144,9 +148,6 @@ impl MakeDefaultHandle for Recorder { .expect("Could not create StatsdRecorder"); metrics::set_boxed_recorder(Box::new(recorder)).unwrap(); - // We don't need to return anything meaningful from here (unlike PrometheusHandle) - // Let's just return an empty tuple. - () } } diff --git a/examples/builder-example/src/main.rs b/examples/builder-example/src/main.rs index e59897c..8e10f0f 100644 --- a/examples/builder-example/src/main.rs +++ b/examples/builder-example/src/main.rs @@ -18,7 +18,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "builder-example=debug".into()), + .unwrap_or_else(|_| "builder_example=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/endpoint-type-example/src/main.rs b/examples/endpoint-type-example/src/main.rs index ce5639c..0741714 100644 --- a/examples/endpoint-type-example/src/main.rs +++ b/examples/endpoint-type-example/src/main.rs @@ -14,7 +14,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "endpoint-type-example=debug".into()), + .unwrap_or_else(|_| "endpoint_type_example=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/exporter-statsd-example/src/main.rs b/examples/exporter-statsd-example/src/main.rs index 16e828d..422e465 100644 --- a/examples/exporter-statsd-example/src/main.rs +++ b/examples/exporter-statsd-example/src/main.rs @@ -15,6 +15,8 @@ struct Recorder; // In order to use this with `axum_prometheus`, we must implement `MakeDefaultHandle`. impl MakeDefaultHandle for Recorder { + // We don't need to return anything meaningful from here (unlike PrometheusHandle) + // Let's just return an empty tuple. type Out = (); fn make_default_handle() -> Self::Out { @@ -26,9 +28,6 @@ impl MakeDefaultHandle for Recorder { .expect("Could not create StatsdRecorder"); metrics::set_boxed_recorder(Box::new(recorder)).unwrap(); - // We don't need to return anything meaningful from here (unlike PrometheusHandle) - // Let's just return an empty tuple. - () } } @@ -37,7 +36,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "exporter-statsd-example=debug".into()), + .unwrap_or_else(|_| "exporter_statsd_example=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/simple-example/src/main.rs b/examples/simple-example/src/main.rs index e0a0997..140a4fc 100644 --- a/examples/simple-example/src/main.rs +++ b/examples/simple-example/src/main.rs @@ -13,7 +13,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "simple-example=debug".into()), + .unwrap_or_else(|_| "simple_example=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/src/builder.rs b/src/builder.rs index 3bfb00b..3089954 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -58,6 +58,8 @@ pub struct MetricLayerBuilder<'a, T, M, S: MetricBuilderState> { pub(crate) traffic: Traffic<'a>, pub(crate) metric_handle: Option, pub(crate) metric_prefix: Option, + pub(crate) enable_body_size: bool, + pub(crate) no_initialize_metrics: bool, pub(crate) _marker: PhantomData<(S, M)>, } @@ -144,6 +146,23 @@ where self.traffic.with_endpoint_label_type(endpoint_label); self } + + /// Enable response body size tracking. + /// + /// #### Note: + /// This may introduce some performance overhead. + pub fn enable_response_body_size(mut self, enable: bool) -> Self { + self.enable_body_size = enable; + self + } + + /// By default, all metrics are initialized via `metrics::describe_*` macros, setting descriptions and units. + /// + /// This function disables this initialization. + pub fn no_initialize_metrics(mut self) -> Self { + self.no_initialize_metrics = true; + self + } } impl<'a, T, M> MetricLayerBuilder<'a, T, M, LayerOnly> @@ -156,7 +175,9 @@ where _marker: PhantomData, traffic: Traffic::new(), metric_handle: None, + no_initialize_metrics: false, metric_prefix: None, + enable_body_size: false, } } @@ -166,7 +187,9 @@ where /// - `{prefix}_http_requests_pending` /// - `{prefix}_http_requests_duration_seconds` /// - /// Note that this will take precedence over environment variables. + /// ..and will also use `{prefix}_http_response_body_size`, if response body size tracking is enabled. + /// + /// This method will take precedence over environment variables. /// /// ## Note /// @@ -249,11 +272,16 @@ where if let Some(prefix) = layer_only.metric_prefix.as_ref() { set_prefix(prefix); } + if !layer_only.no_initialize_metrics { + describe_metrics(layer_only.enable_body_size); + } MetricLayerBuilder { _marker: PhantomData, traffic: layer_only.traffic, metric_handle: layer_only.metric_handle, + no_initialize_metrics: layer_only.no_initialize_metrics, metric_prefix: layer_only.metric_prefix, + enable_body_size: layer_only.enable_body_size, } } } @@ -273,3 +301,28 @@ where /// A builder for [`crate::PrometheusMetricLayer`] that enables further customizations. pub type PrometheusMetricLayerBuilder<'a, S> = MetricLayerBuilder<'a, PrometheusHandle, crate::Handle, S>; + +fn describe_metrics(enable_body_size: bool) { + metrics::describe_counter!( + crate::utils::requests_total_name(), + metrics::Unit::Count, + "The number of times a HTTP request was processed." + ); + metrics::describe_gauge!( + crate::utils::requests_pending_name(), + metrics::Unit::Count, + "The number of currently in-flight requests." + ); + metrics::describe_histogram!( + crate::utils::requests_duration_name(), + metrics::Unit::Seconds, + "The distribution of HTTP response times." + ); + if enable_body_size { + metrics::describe_histogram!( + crate::utils::response_body_size_name(), + metrics::Unit::Count, + "The distribution of HTTP response body sizes." + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 8baf97a..04edfeb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ //! - `axum_http_requests_duration_seconds` (labels: endpoint, method, status): the request duration for all HTTP requests handled (histogram) //! - `axum_http_requests_pending` (labels: endpoint, method): the number of currently in-flight requests (gauge) //! -//! Note that in the future request size metric is also planned to be implemented. +//! This crate also allows to track response body sizes as a histogram — see [`PrometheusMetricLayerBuilder::enable_response_body_size`]. //! //! ### Renaming Metrics //! @@ -17,6 +17,7 @@ //! - `AXUM_HTTP_REQUESTS_TOTAL` //! - `AXUM_HTTP_REQUESTS_DURATION_SECONDS` //! - `AXUM_HTTP_REQUESTS_PENDING` +//! - `AXUM_HTTP_RESPONSE_BODY_SIZE` (if body size tracking is enabled) //! //! These environmental variables can be set in your `.cargo/config.toml` since Cargo 1.56: //! ```toml @@ -24,6 +25,7 @@ //! AXUM_HTTP_REQUESTS_TOTAL = "my_app_requests_total" //! AXUM_HTTP_REQUESTS_DURATION_SECONDS = "my_app_requests_duration_seconds" //! AXUM_HTTP_REQUESTS_PENDING = "my_app_requests_pending" +//! AXUM_HTTP_RESPONSE_BODY_SIZE = "my_app_response_body_size" //! ``` //! //! ..or optionally use [`PrometheusMetricLayerBuilder::with_prefix`] function. @@ -162,16 +164,27 @@ pub const AXUM_HTTP_REQUESTS_TOTAL: &str = match option_env!("AXUM_HTTP_REQUESTS None => "axum_http_requests_total", }; +/// Identifies the histogram/summary used for response body size. Defaults to `axum_http_response_body_size`, +/// but can be changed by setting the `AXUM_HTTP_RESPONSE_BODY_SIZE` env at compile time. +pub const AXUM_HTTP_RESPONSE_BODY_SIZE: &str = match option_env!("AXUM_HTTP_RESPONSE_BODY_SIZE") { + Some(n) => n, + None => "axum_http_response_body_size", +}; + #[doc(hidden)] pub static PREFIXED_HTTP_REQUESTS_TOTAL: OnceCell = OnceCell::new(); #[doc(hidden)] pub static PREFIXED_HTTP_REQUESTS_DURATION_SECONDS: OnceCell = OnceCell::new(); #[doc(hidden)] pub static PREFIXED_HTTP_REQUESTS_PENDING: OnceCell = OnceCell::new(); +#[doc(hidden)] +pub static PREFIXED_HTTP_RESPONSE_BODY_SIZE: OnceCell = OnceCell::new(); use std::borrow::Cow; use std::collections::HashMap; use std::marker::PhantomData; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; use std::time::Instant; mod builder; @@ -184,6 +197,7 @@ pub use builder::MetricLayerBuilder; pub use builder::PrometheusMetricLayerBuilder; use builder::{LayerOnly, Paired}; use lifecycle::layer::LifeCycleLayer; +use lifecycle::OnBodyChunk; use lifecycle::{service::LifeCycle, Callbacks}; use metrics::{decrement_gauge, histogram, increment_counter, increment_gauge}; use once_cell::sync::OnceCell; @@ -209,7 +223,6 @@ fn set_prefix(prefix: impl AsRef) { PREFIXED_HTTP_REQUESTS_TOTAL .set(format!("{}_http_requests_total", prefix.as_ref())) .expect("the prefix has already been set, and can only be set once."); - PREFIXED_HTTP_REQUESTS_DURATION_SECONDS .set(format!( "{}_http_requests_duration_seconds", @@ -219,6 +232,9 @@ fn set_prefix(prefix: impl AsRef) { PREFIXED_HTTP_REQUESTS_PENDING .set(format!("{}_http_requests_pending", prefix.as_ref())) .expect("the prefix has already been set, and can only be set once."); + PREFIXED_HTTP_RESPONSE_BODY_SIZE + .set(format!("{}_http_response_body_size", prefix.as_ref())) + .expect("the prefix has already been set, and can only be set once."); } /// A marker struct that implements the [`lifecycle::Callbacks`] trait. @@ -285,6 +301,65 @@ pub struct MetricsData { pub endpoint: String, pub start: Instant, pub method: &'static str, + pub body_size: f64, + // FIXME: Unclear at the moment, maybe just a simple bool could suffice here? + pub(crate) exact_body_size_called: Arc, +} + +/// A marker struct that implements [`lifecycle::OnBodyChunk`], so it can be used to track response body sizes. +#[derive(Clone)] +pub struct BodySizeRecorder; + +impl OnBodyChunk for BodySizeRecorder +where + B: bytes::Buf, +{ + type Data = Option; + + #[inline] + fn call(&mut self, body: &B, body_size: Option, data: &mut Self::Data) { + let Some(metrics_data) = data else { return }; + // If the exact body size is known ahead of time, we'll just call this whole thing once. + if let Some(exact_size) = body_size { + if !metrics_data + .exact_body_size_called + .swap(true, std::sync::atomic::Ordering::Relaxed) + { + // If the body size is enormous, we lose some precision. It shouldn't matter really. + metrics_data.body_size = exact_size as f64; + body_size_histogram(metrics_data); + } + } else { + // Otherwise, sum all the chunks. + metrics_data.body_size += body.remaining() as f64; + body_size_histogram(metrics_data); + } + } +} + +impl OnBodyChunk for Option +where + T: OnBodyChunk, + B: bytes::Buf, +{ + type Data = T::Data; + + fn call(&mut self, body: &B, body_size: Option, data: &mut Self::Data) { + if let Some(this) = self { + T::call(this, body, body_size, data); + } + } +} + +fn body_size_histogram(metrics_data: &MetricsData) { + let labels = &[ + ("method", metrics_data.method.to_owned()), + ("endpoint", metrics_data.endpoint.clone()), + ]; + let response_body_size = PREFIXED_HTTP_RESPONSE_BODY_SIZE + .get() + .map_or(AXUM_HTTP_RESPONSE_BODY_SIZE, |s| s.as_str()); + metrics::histogram!(response_body_size, metrics_data.body_size, labels); } impl<'a, FailureClass> Callbacks for Traffic<'a> { @@ -335,6 +410,8 @@ impl<'a, FailureClass> Callbacks for Traffic<'a> { endpoint, start: now, method, + body_size: 0.0, + exact_body_size_called: Arc::new(AtomicBool::new(false)), }) } @@ -380,7 +457,11 @@ impl<'a, FailureClass> Callbacks for Traffic<'a> { /// The tower middleware layer for recording http metrics with different exporters. pub struct GenericMetricLayer<'a, T, M> { - pub(crate) inner_layer: LifeCycleLayer, Traffic<'a>>, + pub(crate) inner_layer: LifeCycleLayer< + SharedClassifier, + Traffic<'a>, + Option, + >, _marker: PhantomData<(T, M)>, } @@ -389,7 +470,7 @@ impl<'a, T, M> std::clone::Clone for GenericMetricLayer<'a, T, M> { fn clone(&self) -> Self { GenericMetricLayer { inner_layer: self.inner_layer.clone(), - _marker: self._marker.clone(), + _marker: self._marker, } } } @@ -455,17 +536,25 @@ where pub fn new() -> Self { let make_classifier = StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier(); - let inner_layer = LifeCycleLayer::new(make_classifier, Traffic::new()); + let inner_layer = LifeCycleLayer::new(make_classifier, Traffic::new(), None); Self { inner_layer, _marker: PhantomData, } } + pub fn enable_response_body_size(&mut self) { + self.inner_layer.on_body_chunk(Some(BodySizeRecorder)); + } + pub(crate) fn from_builder(builder: MetricLayerBuilder<'a, T, M, LayerOnly>) -> Self { let make_classifier = StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier(); - let inner_layer = LifeCycleLayer::new(make_classifier, builder.traffic); + let inner_layer = if builder.enable_body_size { + LifeCycleLayer::new(make_classifier, builder.traffic, Some(BodySizeRecorder)) + } else { + LifeCycleLayer::new(make_classifier, builder.traffic, None) + }; Self { inner_layer, _marker: PhantomData, @@ -475,7 +564,11 @@ where pub(crate) fn pair_from_builder(builder: MetricLayerBuilder<'a, T, M, Paired>) -> (Self, T) { let make_classifier = StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier(); - let inner_layer = LifeCycleLayer::new(make_classifier, builder.traffic); + let inner_layer = if builder.enable_body_size { + LifeCycleLayer::new(make_classifier, builder.traffic, Some(BodySizeRecorder)) + } else { + LifeCycleLayer::new(make_classifier, builder.traffic, None) + }; ( Self { @@ -531,7 +624,12 @@ where } impl<'a, S, T, M> Layer for GenericMetricLayer<'a, T, M> { - type Service = LifeCycle, Traffic<'a>>; + type Service = LifeCycle< + S, + SharedClassifier, + Traffic<'a>, + Option, + >; fn layer(&self, inner: S) -> Self::Service { self.inner_layer.layer(inner) diff --git a/src/lifecycle/body.rs b/src/lifecycle/body.rs index 46a4b95..d485d42 100644 --- a/src/lifecycle/body.rs +++ b/src/lifecycle/body.rs @@ -1,5 +1,6 @@ -use super::{Callbacks, FailedAt}; +use super::{Callbacks, FailedAt, OnBodyChunk}; use futures_core::ready; +use http::HeaderValue; use http_body::Body; use pin_project::pin_project; use std::{ @@ -11,18 +12,24 @@ use tower_http::classify::ClassifyEos; /// Response body for [`LifeCycle`]. #[pin_project] -pub struct ResponseBody { +pub struct ResponseBody { #[pin] pub(super) inner: B, - pub(super) parts: Option<(C, Callbacks, CallbacksData)>, + pub(super) parts: Option<(C, Callbacks)>, + pub(super) callbacks_data: CallbacksData, + pub(super) on_body_chunk: OnBodyChunk, + pub(super) content_length: Option, } -impl Body for ResponseBody +impl Body + for ResponseBody where B: Body, B::Error: fmt::Display + 'static, C: ClassifyEos, CallbacksT: Callbacks, + OnBodyChunkT: OnBodyChunk, + CallbacksData: Clone, { type Data = B::Data; type Error = B::Error; @@ -33,21 +40,29 @@ where ) -> Poll>> { let this = self.project(); - let result = ready!(this.inner.poll_data(cx)); + let body_size = this.inner.size_hint().exact(); + let Some(result) = ready!(this.inner.poll_data(cx)) else { + return Poll::Ready(None); + }; + + let body_size = body_size.or_else(|| { + this.content_length + .as_ref() + .and_then(|cl| cl.to_str().ok()) + .and_then(|cl| cl.parse().ok()) + }); match result { - None => Poll::Ready(None), - Some(Ok(chunk)) => { - if let Some((_, callbacks, callbacks_data)) = &this.parts { - callbacks.on_body_chunk(&chunk, callbacks_data); - } + Ok(chunk) => { + this.on_body_chunk + .call(&chunk, body_size, this.callbacks_data); Poll::Ready(Some(Ok(chunk))) } - Some(Err(err)) => { - if let Some((classify_eos, callbacks, callbacks_data)) = this.parts.take() { + Err(err) => { + if let Some((classify_eos, callbacks)) = this.parts.take() { let classification = classify_eos.classify_error(&err); - callbacks.on_failure(FailedAt::Body, classification, callbacks_data); + callbacks.on_failure(FailedAt::Body, classification, this.callbacks_data); } Poll::Ready(Some(Err(err))) @@ -65,18 +80,18 @@ where match result { Ok(trailers) => { - if let Some((classify_eos, callbacks, callbacks_data)) = this.parts.take() { + if let Some((classify_eos, callbacks)) = this.parts.take() { let trailers = trailers.as_ref(); let classification = classify_eos.classify_eos(trailers); - callbacks.on_eos(trailers, classification, callbacks_data); + callbacks.on_eos(trailers, classification, this.callbacks_data.clone()); } Poll::Ready(Ok(trailers)) } Err(err) => { - if let Some((classify_eos, callbacks, callbacks_data)) = this.parts.take() { + if let Some((classify_eos, callbacks)) = this.parts.take() { let classification = classify_eos.classify_error(&err); - callbacks.on_failure(FailedAt::Trailers, classification, callbacks_data); + callbacks.on_failure(FailedAt::Trailers, classification, this.callbacks_data); } Poll::Ready(Err(err)) diff --git a/src/lifecycle/future.rs b/src/lifecycle/future.rs index 70b0ad0..8d94334 100644 --- a/src/lifecycle/future.rs +++ b/src/lifecycle/future.rs @@ -9,28 +9,33 @@ use std::{ }; use tower_http::classify::{ClassifiedResponse, ClassifyResponse}; -use super::{body::ResponseBody, Callbacks, FailedAt}; +use super::{body::ResponseBody, Callbacks, FailedAt, OnBodyChunk}; #[pin_project] -pub struct ResponseFuture { +pub struct ResponseFuture { #[pin] pub(super) inner: F, pub(super) classifier: Option, pub(super) callbacks: Option, + pub(super) on_body_chunk: Option, pub(super) callbacks_data: Option, } -impl Future - for ResponseFuture +impl Future + for ResponseFuture where F: Future, E>>, ResBody: Body, C: ClassifyResponse, CallbacksT: Callbacks, E: std::fmt::Display + 'static, + OnBodyChunkT: OnBodyChunk, + CallbacksData: Clone, { - type Output = - Result>, E>; + type Output = Result< + Response>, + E, + >; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let result = ready!(this.inner.poll(cx)); @@ -47,9 +52,14 @@ where .callbacks_data .take() .expect("polled future after completion"); + let on_body_chunk = this + .on_body_chunk + .take() + .expect("polled future after completion"); match result { Ok(res) => { + let content_length = res.headers().get(http::header::CONTENT_LENGTH).cloned(); let classification = classifier.classify_response(&res); match classification { @@ -62,6 +72,9 @@ where let res = res.map(|body| ResponseBody { inner: body, parts: None, + on_body_chunk, + callbacks_data: callbacks_data.clone(), + content_length, }); Poll::Ready(Ok(res)) } @@ -73,7 +86,10 @@ where ); let res = res.map(|body| ResponseBody { inner: body, - parts: Some((classify_eos, callbacks, callbacks_data)), + callbacks_data: callbacks_data.clone(), + on_body_chunk, + parts: Some((classify_eos, callbacks)), + content_length, }); Poll::Ready(Ok(res)) } @@ -81,7 +97,7 @@ where } Err(err) => { let classification = classifier.classify_error(&err); - callbacks.on_failure(FailedAt::Response, classification, callbacks_data); + callbacks.on_failure(FailedAt::Response, classification, &mut callbacks_data); Poll::Ready(Err(err)) } } diff --git a/src/lifecycle/layer.rs b/src/lifecycle/layer.rs index b3b9970..85798d5 100644 --- a/src/lifecycle/layer.rs +++ b/src/lifecycle/layer.rs @@ -8,33 +8,41 @@ use super::service::LifeCycle; /// /// [`Layer`]: tower::Layer #[derive(Debug, Clone)] -pub struct LifeCycleLayer { +pub struct LifeCycleLayer { pub(super) make_classifier: MC, pub(super) callbacks: Callbacks, + pub(super) on_body_chunk: OnBodyChunk, } -impl LifeCycleLayer { +impl LifeCycleLayer { /// Create a new `LifeCycleLayer`. - pub fn new(make_classifier: MC, callbacks: Callbacks) -> Self { + pub fn new(make_classifier: MC, callbacks: Callbacks, on_body_chunk: OnBodyChunk) -> Self { LifeCycleLayer { make_classifier, callbacks, + on_body_chunk, } } + + pub(crate) fn on_body_chunk(&mut self, on_body_chunk: OnBodyChunk) { + self.on_body_chunk = on_body_chunk; + } } -impl Layer for LifeCycleLayer +impl Layer for LifeCycleLayer where MC: Clone, Callbacks: Clone, + OnBodyChunk: Clone, { - type Service = LifeCycle; + type Service = LifeCycle; fn layer(&self, inner: S) -> Self::Service { LifeCycle { inner, make_classifier: self.make_classifier.clone(), callbacks: self.callbacks.clone(), + on_body_chunk: self.on_body_chunk.clone(), } } } diff --git a/src/lifecycle/mod.rs b/src/lifecycle/mod.rs index 1f74d23..0003445 100644 --- a/src/lifecycle/mod.rs +++ b/src/lifecycle/mod.rs @@ -54,17 +54,6 @@ pub trait Callbacks: Sized { ) { } - /// Perform some action when a response body chunk has been generated. - /// - /// This is called when [`Body::poll_data`] completes with `Some(Ok(chunk))` - /// regardless if the chunk is empty or not. - /// - /// The default implementation does nothing and returns immediately. - /// - /// [`Body::poll_data`]: http_body::Body::poll_data - #[inline] - fn on_body_chunk(&self, _check: &B, _data: &Self::Data) {} - /// Perform some action when a stream has ended. /// /// This is called when [`Body::poll_trailers`] completes with @@ -118,11 +107,27 @@ pub trait Callbacks: Sized { self, _failed_at: FailedAt, _failure_classification: FailureClass, - _data: Self::Data, + _data: &mut Self::Data, ) { } } +/// A trait that allows to hook into [`http_body::Body::poll_data`]'s lifecycle. +pub trait OnBodyChunk { + type Data; + + /// Perform some action when a response body chunk has been generated. + /// + /// This is called when [`Body::poll_data`] completes with `Some(Ok(chunk))` + /// regardless if the chunk is empty or not. + /// + /// The default implementation does nothing and returns immediately. + /// + /// [`Body::poll_data`]: http_body::Body::poll_data + #[inline] + fn call(&mut self, _body: &B, _exact_body_size: Option, _data: &mut Self::Data) {} +} + /// Enum used to specify where an error was encountered. #[derive(Debug)] pub enum FailedAt { diff --git a/src/lifecycle/service.rs b/src/lifecycle/service.rs index c6576f4..40694ac 100644 --- a/src/lifecycle/service.rs +++ b/src/lifecycle/service.rs @@ -5,26 +5,39 @@ use http_body::Body; use tower::Service; use tower_http::classify::MakeClassifier; -use super::{body::ResponseBody, future::ResponseFuture, layer::LifeCycleLayer, Callbacks}; +use super::{ + body::ResponseBody, future::ResponseFuture, layer::LifeCycleLayer, Callbacks, OnBodyChunk, +}; #[derive(Clone, Debug)] -pub struct LifeCycle { +pub struct LifeCycle { pub(super) inner: S, pub(super) make_classifier: MC, pub(super) callbacks: Callbacks, + pub(super) on_body_chunk: OnBodyChunk, } -impl LifeCycle { - pub fn new(inner: S, make_classifier: MC, callbacks: Callbacks) -> Self { +impl LifeCycle { + pub fn new( + inner: S, + make_classifier: MC, + callbacks: Callbacks, + on_body_chunk: OnBodyChunk, + ) -> Self { Self { inner, make_classifier, callbacks, + on_body_chunk, } } - pub fn layer(make_classifier: MC, callbacks: Callbacks) -> LifeCycleLayer { - LifeCycleLayer::new(make_classifier, callbacks) + pub fn layer( + make_classifier: MC, + callbacks: Callbacks, + on_body_chunk: OnBodyChunk, + ) -> LifeCycleLayer { + LifeCycleLayer::new(make_classifier, callbacks, on_body_chunk) } /// Gets a reference to the underlying service. @@ -43,17 +56,23 @@ impl LifeCycle { } } -impl Service> for LifeCycle +impl Service> + for LifeCycle where S: Service, Response = Response>, ResBody: Body, MC: MakeClassifier, CallbacksT: Callbacks + Clone, S::Error: std::fmt::Display + 'static, + OnBodyChunkT: OnBodyChunk + Clone, + CallbacksT::Data: Clone, { - type Response = Response>; + type Response = Response< + ResponseBody, + >; type Error = S::Error; - type Future = ResponseFuture; + type Future = + ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) @@ -69,6 +88,7 @@ where classifier: Some(classifier), callbacks: Some(self.callbacks.clone()), callbacks_data: Some(callbacks_data), + on_body_chunk: Some(self.on_body_chunk.clone()), } } } diff --git a/src/utils.rs b/src/utils.rs index e557624..77a5e7c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,9 +1,10 @@ +//! Utilities for getting metric names at runtime, and other helpers. use http::Method; use crate::{ AXUM_HTTP_REQUESTS_DURATION_SECONDS, AXUM_HTTP_REQUESTS_PENDING, AXUM_HTTP_REQUESTS_TOTAL, - PREFIXED_HTTP_REQUESTS_DURATION_SECONDS, PREFIXED_HTTP_REQUESTS_PENDING, - PREFIXED_HTTP_REQUESTS_TOTAL, + AXUM_HTTP_RESPONSE_BODY_SIZE, PREFIXED_HTTP_REQUESTS_DURATION_SECONDS, + PREFIXED_HTTP_REQUESTS_PENDING, PREFIXED_HTTP_REQUESTS_TOTAL, PREFIXED_HTTP_RESPONSE_BODY_SIZE, }; /// Standard HTTP request duration buckets measured in seconds. The default buckets are tailored to broadly @@ -57,3 +58,13 @@ pub fn requests_pending_name() -> &'static str { .get() .map_or(AXUM_HTTP_REQUESTS_PENDING, |s| s.as_str()) } + +/// The name of the response body size metric. By default, it's the same as [`AXUM_HTTP_RESPONSE_BODY_SIZE`], but +/// can be changed via the [`with_prefix`] function. +/// +/// [`with_prefix`]: crate::MetricLayerBuilder::with_prefix +pub fn response_body_size_name() -> &'static str { + PREFIXED_HTTP_RESPONSE_BODY_SIZE + .get() + .map_or(AXUM_HTTP_RESPONSE_BODY_SIZE, |s| s.as_str()) +}