From c2ce5c1752eb6000ec7a66f354fb8648bf637028 Mon Sep 17 00:00:00 2001 From: Isawan Millican Date: Sat, 22 Jun 2024 23:34:00 +0100 Subject: [PATCH] fix: Implemented graceful shutdown Wait until all in-flight requests are responded to before shutdown --- src/lib.rs | 84 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 49 insertions(+), 35 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ca93fdc8..941d8e8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ use aws_config::BehaviorVersion; use axum::extract::Request; use axum_prometheus::metrics_exporter_prometheus::PrometheusHandle; use config::{Args, ServerArgs}; +use futures::join; use hyper::body::Incoming; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -25,10 +26,11 @@ use tokio::{ net::TcpListener, select, sync::{mpsc, oneshot::Sender}, + task::JoinSet, }; use tokio_util::sync::CancellationToken; use tower::Service; -use tracing::error; +use tracing::{error, warn}; use crate::{ credhelper::database::DatabaseCredentials, refresh::refresher, registry::RegistryClient, @@ -39,35 +41,51 @@ pub struct StartUpNotify { pub msg: T, } -async fn serve(listener: TcpListener, app: axum::Router) { +async fn serve(listener: TcpListener, cancel: CancellationToken, app: axum::Router) { + let mut join_set = JoinSet::new(); loop { - let (socket, _remote_addr) = listener.accept().await.unwrap(); - - let tower_service = app.clone(); - - // Spawn a task to handle the connection. That way we can multiple connections - // concurrently. - tokio::spawn(async move { - let socket = TokioIo::new(socket); - - // Hyper also has its own `Service` trait and doesn't use tower. We can use - // `hyper::service::service_fn` to create a hyper `Service` that calls our app through - // `tower::Service::call`. - let hyper_service = hyper::service::service_fn(move |request: Request| { - // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas - // tower's `Service` requires `&mut self`. - // - // We don't need to call `poll_ready` since `Router` is always ready. - tower_service.clone().call(request) - }); - - if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection(socket, hyper_service) - .await - { - tracing::warn!("failed to serve connection: {err:#}"); + select! { + result = listener.accept() => { + let (socket, _remote_addr) = match result { + Ok(x) => x, + Err(err) => { + warn!(reason = %err, "failed to accept connection"); + continue; + } + }; + let tower_service = app.clone(); + + // Spawn a task to handle the connection. That way we can multiple connections + // concurrently. + join_set.spawn(async move { + let socket = TokioIo::new(socket); + + // Hyper also has its own `Service` trait and doesn't use tower. We can use + // `hyper::service::service_fn` to create a hyper `Service` that calls our app through + // `tower::Service::call`. + let hyper_service = hyper::service::service_fn(move |request: Request| { + // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas + // tower's `Service` requires `&mut self`. + // + // We don't need to call `poll_ready` since `Router` is always ready. + tower_service.clone().call(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection(socket, hyper_service) + .await + { + tracing::warn!("failed to serve connection: {err:#}"); + } + }); + while join_set.try_join_next().is_some() {} // clean up any completed requests } - }); + + _ = cancel.cancelled() => { + break; + } + } + while (join_set.join_next().await).is_some() {} // wait until all requests are done } } @@ -106,7 +124,7 @@ pub async fn setup_server( // path style required for minio to work // Set up AWS SDK - let aws_config = aws_config::defaults(BehaviorVersion::v2023_11_09()) + let aws_config = aws_config::defaults(BehaviorVersion::v2024_03_28()) .load() .await; let mut s3_config = aws_sdk_s3::config::Builder::from(&aws_config).force_path_style(true); @@ -183,17 +201,13 @@ pub async fn run_server( let listener = TcpListener::bind(&bind_addr).await.unwrap(); let local_addr = listener.local_addr().unwrap(); - let server = serve(listener, app); + let server = serve(listener, cancel.child_token(), app); startup .send(StartUpNotify { msg: local_addr }) .expect("Sender channel has already been used"); - select! { - _ = server => (), - _ = refresher => (), - _ = cancel.cancelled() => tracing::trace!("Cancellation requested"), - } + join!(server, refresher); tracing::debug!("Shutting down server"); Ok(()) }