Skip to content

Commit

Permalink
Add graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
tiberiuv committed Feb 23, 2025
1 parent 145b85e commit 224d866
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 23 deletions.
29 changes: 18 additions & 11 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::net::{IpAddr, SocketAddr};
use std::time::Duration;

use crate::crowdsec::CrowdsecAppsecApi;
use axum::extract::{ConnectInfo, FromRef, FromRequestParts, Request, State};
Expand All @@ -9,6 +10,7 @@ use axum::routing::{get, MethodRouter};
use axum::{async_trait, Json, RequestPartsExt, Router};
use ipnet::IpNet;
use reqwest::StatusCode;
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;

use crate::App;
Expand Down Expand Up @@ -45,7 +47,7 @@ where
?remote_client_ip,
"Received request from untrusted ip rejecting...",
);
return Err((StatusCode::FORBIDDEN, "Forbidden"));
return Err((StatusCode::FORBIDDEN, "Request blocked!"));
}

let real_client_ip = get_client_ip_x_forwarded_for(
Expand Down Expand Up @@ -149,7 +151,7 @@ async fn check_ip(

if app.blacklist.contains(real_client_ip) {
tracing::info!(real_client_ip = real_client_ip.to_string(), "Ip is banned!");
return StatusCode::FORBIDDEN.into_response();
return (StatusCode::FORBIDDEN, "Request blocked!").into_response();
}

let result = app
Expand All @@ -158,12 +160,12 @@ async fn check_ip(
.await;
match result {
Ok(is_ok) => if is_ok {
StatusCode::OK
StatusCode::OK.into_response()
} else {
StatusCode::FORBIDDEN
(StatusCode::FORBIDDEN, "Request blocked!").into_response()
}
.into_response(),
Err(_err) => StatusCode::FORBIDDEN.into_response(),
Err(_err) => (StatusCode::FORBIDDEN, "Request blocked!").into_response(),
}
}

Expand Down Expand Up @@ -207,19 +209,24 @@ fn api_server_router(state: App) -> Router {
}
}),
)
.layer(TimeoutLayer::new(Duration::from_secs(5)))
.with_state(state)
}

pub async fn api_server_listen(state: App, socket_addr: SocketAddr) -> std::io::Result<()> {
pub async fn api_server_listen(
state: App,
socket_addr: SocketAddr,
handle: axum_server::Handle,
) -> std::io::Result<()> {
let router = api_server_router(state);

tracing::info!(listen = ?socket_addr, "Starting API server");
let listener = tokio::net::TcpListener::bind(socket_addr).await.unwrap();
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.await

axum_server::from_tcp(listener.into_std()?)
.handle(handle)
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await
}

#[cfg(test)]
Expand Down
48 changes: 36 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::LazyLock;

use clap::Parser;
use rustls::crypto::CryptoProvider;
use tokio::signal;
use tracing::info;
use waf_bouncer::api::api_server_listen;
use waf_bouncer::cli::ClientCerts;
Expand All @@ -12,6 +13,33 @@ use waf_bouncer::{

pub static BLACKLIST_CACHE: LazyLock<BlacklistCache> = LazyLock::new(Default::default);

async fn shutdown_signal(handle: axum_server::Handle) {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};

#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};

#[cfg(not(unix))]
let terminate = std::future::pending::<()>();

tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}

tracing::info!("Received termination signal shutting down");
handle.graceful_shutdown(Some(std::time::Duration::from_secs(10)));
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
CryptoProvider::install_default(rustls::crypto::aws_lc_rs::default_provider())
Expand Down Expand Up @@ -47,20 +75,16 @@ async fn main() -> anyhow::Result<()> {
};
info!(?app.config, "config");

let mut task_set = tokio::task::JoinSet::new();
let app_clone = app.clone();
task_set.spawn(async move { reconcile(app_clone).await });
task_set.spawn(async move {
api_server_listen(app, cli.listen_addr)
.await
.map_err(anyhow::Error::new)
});
let handle = axum_server::Handle::new();
let shutdown_future = shutdown_signal(handle.clone());

while let Some(res) = task_set.join_next().await {
res??;
tokio::select! {
_ = reconcile(app.clone()) => {}
_ = api_server_listen(app, cli.listen_addr, handle) => {}
_ = shutdown_future => {
info!("Exit")
}
}

info!("Exit!");

Ok(())
}

0 comments on commit 224d866

Please sign in to comment.