diff --git a/Cargo.lock b/Cargo.lock index ad74e4df..1719c394 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -170,6 +170,12 @@ dependencies = [ "syn 2.0.53", ] +[[package]] +name = "auto-future" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c1e7e457ea78e524f48639f551fd79703ac3f2237f5ecccdf4708f8a75ad373" + [[package]] name = "autocfg" version = "1.1.0" @@ -217,7 +223,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.2.0", + "hyper 1.3.1", "hyper-util", "itoa", "matchit", @@ -298,6 +304,35 @@ dependencies = [ "tower-http", ] +[[package]] +name = "axum-test" +version = "14.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c5dd01c3ff7d926efc6db38bc9a2a5fe82ebe3bf85e44200a7ae3b6bda5f4e5" +dependencies = [ + "anyhow", + "async-trait", + "auto-future", + "axum 0.7.5", + "bytes", + "cookie", + "http 1.1.0", + "http-body-util", + "hyper 1.3.1", + "hyper-util", + "mime", + "pretty_assertions", + "reserve-port", + "rust-multipart-rfc7578_2", + "serde", + "serde_json", + "serde_urlencoded", + "smallvec", + "tokio", + "tower", + "url", +] + [[package]] name = "backtrace" version = "0.3.70" @@ -340,7 +375,7 @@ dependencies = [ "bitflags 2.5.0", "cexpr", "clang-sys", - "itertools 0.11.0", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -453,9 +488,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "cast" @@ -493,6 +528,7 @@ version = "0.7.8" dependencies = [ "axum 0.7.5", "axum-prometheus", + "axum-test", "base64 0.22.1", "clap", "config", @@ -708,6 +744,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "time", + "version_check", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -953,6 +999,12 @@ dependencies = [ "serde", ] +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "digest" version = "0.10.7" @@ -980,15 +1032,15 @@ checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" [[package]] name = "duration-str" -version = "0.7.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8bb6a301a95ba86fa0ebaf71d49ae4838c51f8b84cb88ed140dfb66452bb3c4" +checksum = "7c1a2e028bbf7921549873b291ddc0cfe08b673d9489da81ac28898cd5a0e6e0" dependencies = [ - "nom", "rust_decimal", "serde", "thiserror", "time", + "winnow 0.6.8", ] [[package]] @@ -1509,9 +1561,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" +checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" dependencies = [ "bytes", "futures-channel", @@ -1548,7 +1600,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.2.0", + "hyper 1.3.1", "hyper-util", "native-tls", "tokio", @@ -1567,7 +1619,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.2.0", + "hyper 1.3.1", "pin-project-lite", "socket2", "tokio", @@ -1740,6 +1792,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -1915,6 +1976,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2530,6 +2601,16 @@ dependencies = [ "termtree", ] +[[package]] +name = "pretty_assertions" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cee1a6c8a5b9208b3cb1061f10c0cb689087b3d8ce85fb9d2dd7a29b6ba66" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "prettyplease" version = "0.2.16" @@ -2646,7 +2727,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.53", @@ -2890,7 +2971,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.2.0", + "hyper 1.3.1", "hyper-tls", "hyper-util", "ipnet", @@ -2917,6 +2998,16 @@ dependencies = [ "winreg", ] +[[package]] +name = "reserve-port" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9838134a2bfaa8e1f40738fcc972ac799de6e0e06b5157acb95fc2b05a0ea283" +dependencies = [ + "lazy_static", + "thiserror", +] + [[package]] name = "rgb" version = "0.8.37" @@ -2992,6 +3083,22 @@ dependencies = [ "ordered-multimap", ] +[[package]] +name = "rust-multipart-rfc7578_2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03b748410c0afdef2ebbe3685a6a862e2ee937127cdaae623336a459451c8d57" +dependencies = [ + "bytes", + "futures-core", + "futures-util", + "http 0.2.12", + "mime", + "mime_guess", + "rand", + "thiserror", +] + [[package]] name = "rust_decimal" version = "1.34.3" @@ -3786,7 +3893,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "winnow 0.6.5", + "winnow 0.6.8", ] [[package]] @@ -3978,6 +4085,15 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.15" @@ -4427,9 +4543,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.6.5" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8" +checksum = "c3c52e9c97a68071b23e836c9380edae937f17b9c4667bd021973efc689f618d" dependencies = [ "memchr", ] @@ -4462,6 +4578,12 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + [[package]] name = "yasna" version = "0.5.2" diff --git a/certifier/Cargo.toml b/certifier/Cargo.toml index 180f1c7e..24bca737 100644 --- a/certifier/Cargo.toml +++ b/certifier/Cargo.toml @@ -26,8 +26,8 @@ rand = "0.8.5" serde_json = "1.0.117" base64 = "0.22.1" axum-prometheus = "0.6.1" -tower = { version = "0.4.13", features = ["limit"] } -duration-str = { version = "0.7.1", default-features = false, features = [ +tower = { version = "0.4.13", features = ["limit", "load-shed", "buffer"] } +duration-str = { version = "0.10.0", default-features = false, features = [ "serde", "time", ] } @@ -35,5 +35,6 @@ parity-scale-codec = { version = "3.6.12", features = ["derive", "serde"] } mockall = "0.12.1" [dev-dependencies] +axum-test = "14.9.1" reqwest = { version = "0.12.4", features = ["json"] } tempfile = "3.8.1" diff --git a/certifier/README.md b/certifier/README.md index 5ba4539f..4e146535 100644 --- a/certifier/README.md +++ b/certifier/README.md @@ -45,6 +45,15 @@ init_cfg: metrics: "127.0.0.1:9090" randomx_mode: Fast + +limits: + # How many requests can be processed in parallel. + # As PoST verification is CPU-bound, it defaults to the number of CPUs. + max_concurrent_requests: 4 + # How many requests can be queued, waiting to be processed. + max_pending_requests: 1000 + # Maximum size of request body (the proof JSON) + max_body_size: 1024 ``` Each field can also be provided as env variable prefixed with CERTIFIER. For example, `CERTIFIER_SIGNING_KEY`. diff --git a/certifier/src/certifier.rs b/certifier/src/certifier.rs index 6b132385..f93b05b4 100644 --- a/certifier/src/certifier.rs +++ b/certifier/src/certifier.rs @@ -1,7 +1,11 @@ use std::sync::Arc; use std::time::{Duration, SystemTime}; +use axum::error_handling::HandleErrorLayer; +use axum::extract::DefaultBodyLimit; use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::BoxError; use axum::{extract::State, Json}; use axum::{routing::post, Router}; use ed25519_dalek::{Signature, Signer, SigningKey}; @@ -11,9 +15,14 @@ use post::pow::randomx::PoW; use post::verification::Mode; use serde::{Deserialize, Serialize}; use serde_with::{base64::Base64, serde_as}; +use tower::buffer::BufferLayer; +use tower::limit::ConcurrencyLimitLayer; +use tower::load_shed::error::Overloaded; +use tower::load_shed::LoadShedLayer; +use tower::ServiceBuilder; use tracing::instrument; -use crate::configuration::RandomXMode; +use crate::configuration::{Limits, RandomXMode}; use crate::time::unix_timestamp; #[derive(Debug, Deserialize, Serialize)] @@ -161,6 +170,32 @@ pub fn new( .with_state(Arc::new(certifier)) } +pub trait RouterLimiter { + fn apply_limits(self, limits: Limits) -> Self; +} + +impl RouterLimiter for Router { + fn apply_limits(self, limits: Limits) -> Self { + self.layer( + ServiceBuilder::new() + .layer(DefaultBodyLimit::max(limits.max_body_size)) + .layer(HandleErrorLayer::new(handle_error)) + .layer(LoadShedLayer::new()) + .layer(BufferLayer::new(limits.max_pending_requests)) + .layer(ConcurrencyLimitLayer::new(limits.max_concurrent_requests)) + .into_inner(), + ) + } +} + +async fn handle_error(error: BoxError) -> Response { + if error.is::() { + StatusCode::TOO_MANY_REQUESTS.into_response() + } else { + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } +} + #[cfg(test)] mod tests { use std::{ @@ -168,13 +203,14 @@ mod tests { time::{Duration, SystemTime}, }; - use crate::time::unix_timestamp; + use crate::{certifier::RouterLimiter, configuration::Limits, time::unix_timestamp}; use super::{Certificate, Certifier, MockVerifier}; + use axum::{body::Bytes, routing::post, Router}; + use axum_test::TestServer; use ed25519_dalek::SigningKey; use parity_scale_codec::Decode; use post::{metadata::ProofMetadata, prove::Proof}; - #[test] fn certify_invalid_post() { let mut verifier = MockVerifier::new(); @@ -259,4 +295,20 @@ mod tests { assert!(expiration >= unix_timestamp(started + expiry)); assert!(expiration <= unix_timestamp(SystemTime::now() + expiry)); } + + #[tokio::test] + async fn limit_max_body_size() { + let my_app = Router::new() + .route("/", post(|_: Bytes| async {})) + .apply_limits(Limits { + max_concurrent_requests: 1, + max_pending_requests: 1, + max_body_size: 5, + }); + + let server = TestServer::new(my_app).unwrap(); + + let response = server.post("/").text("i'm a very long text").await; + assert_eq!(response.status_code(), 413); + } } diff --git a/certifier/src/configuration.rs b/certifier/src/configuration.rs index 6b17f0ac..b9a9a13a 100644 --- a/certifier/src/configuration.rs +++ b/certifier/src/configuration.rs @@ -39,10 +39,7 @@ pub struct Config { /// The address to listen on for incoming requests. pub listen: std::net::SocketAddr, - /// The maximum number of requests to process in parallel. - /// Typically set to the number of cores, which is the default (if not set). - #[serde(default = "max_concurrency")] - pub max_concurrent_requests: usize, + pub limits: Limits, #[serde_as(as = "Base64")] /// The base64-encoded secret key used to sign the proofs. @@ -66,6 +63,26 @@ pub struct Config { pub metrics: Option, } +#[derive(Debug, serde::Deserialize, Clone)] +pub struct Limits { + /// The maximum number of requests to process in parallel. + /// Typically set to the number of cores, which is the default (if not set). + #[serde(default = "max_concurrency")] + pub max_concurrent_requests: usize, + + /// The maximum number of requests to queue up before rejecting new requests. + /// Rejected requests have 429 TOO_MANY_REQUESTS status code. + pub max_pending_requests: usize, + + /// The maximum size of a request body (the proof JSON) + #[serde(default = "default_max_body")] + pub max_body_size: usize, +} + +fn default_max_body() -> usize { + 1024 +} + pub fn get_configuration(config_path: &Path) -> Result { info!("loading configuration from {config_path:?}"); diff --git a/certifier/src/main.rs b/certifier/src/main.rs index 1e3f3463..72ab428d 100644 --- a/certifier/src/main.rs +++ b/certifier/src/main.rs @@ -3,10 +3,10 @@ use std::{future::IntoFuture, path::PathBuf}; use axum::routing::get; use axum_prometheus::PrometheusMetricLayerBuilder; use base64::{engine::general_purpose, Engine as _}; +use certifier::certifier::RouterLimiter; use clap::{arg, Parser, Subcommand}; use ed25519_dalek::SigningKey; use tokio::net::TcpListener; -use tower::limit::ConcurrencyLimitLayer; use tracing::info; use tracing_log::LogTracer; use tracing_subscriber::{EnvFilter, FmtSubscriber}; @@ -77,10 +77,7 @@ async fn main() -> Result<(), Box> { info!("POST proof configuration: {:?}", config.post_cfg); info!("POST init configuration: {:?}", config.init_cfg); info!("RandomX mode: {:?}", config.randomx_mode); - info!( - "max concurrent requests: {}", - config.max_concurrent_requests - ); + info!("{:?}", config.limits); if let Some(expiry) = config.certificate_expiration { info!("generated certificates will expire after {expiry:?}"); } else { @@ -94,7 +91,7 @@ async fn main() -> Result<(), Box> { config.randomx_mode, config.certificate_expiration, ) - .layer(ConcurrencyLimitLayer::new(config.max_concurrent_requests)); + .apply_limits(config.limits); if let Some(addr) = config.metrics { info!("metrics enabled on: http://{addr:?}/metrics");