Skip to content

Commit e29eb01

Browse files
committed
feat: use custom axum-server wrapper with timeouts
- Use axum-server isntead of directly the axum crate (like in the tracker). - Add wrapper to axum-server to enable timeouts.
1 parent b948573 commit e29eb01

File tree

4 files changed

+396
-26
lines changed

4 files changed

+396
-26
lines changed

src/web/api/mod.rs

-6
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ pub struct Running {
3030
pub api_server: Option<JoinHandle<Result<(), std::io::Error>>>,
3131
}
3232

33-
#[must_use]
34-
#[derive(Debug)]
35-
pub struct ServerStartedMessage {
36-
pub socket_addr: SocketAddr,
37-
}
38-
3933
/// Starts the API server.
4034
#[must_use]
4135
pub async fn start(app_data: Arc<AppData>, net_ip: &str, net_port: u16, implementation: &Version) -> api::Running {

src/web/api/server/custom_axum.rs

+275
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
//! Wrapper for Axum server to add timeouts.
2+
//!
3+
//! Copyright (c) Eray Karatay ([@programatik29](https://github.com/programatik29)).
4+
//!
5+
//! See: <https://gist.github.com/programatik29/36d371c657392fd7f322e7342957b6d1>.
6+
//!
7+
//! If a client opens a HTTP connection and it does not send any requests, the
8+
//! connection is closed after a timeout. You can test it with:
9+
//!
10+
//! ```text
11+
//! telnet 127.0.0.1 1212
12+
//! Trying 127.0.0.1...
13+
//! Connected to 127.0.0.1.
14+
//! Escape character is '^]'.
15+
//! Connection closed by foreign host.
16+
//! ```
17+
//!
18+
//! If you want to know more about Axum and timeouts see <https://github.com/josecelano/axum-server-timeout>.
19+
use std::future::Ready;
20+
use std::io::ErrorKind;
21+
use std::net::TcpListener;
22+
use std::pin::Pin;
23+
use std::task::{Context, Poll};
24+
use std::time::Duration;
25+
26+
use axum_server::accept::Accept;
27+
use axum_server::tls_rustls::{RustlsAcceptor, RustlsConfig};
28+
use axum_server::Server;
29+
use futures_util::{ready, Future};
30+
use http_body::{Body, Frame};
31+
use hyper::Response;
32+
use hyper_util::rt::TokioTimer;
33+
use pin_project_lite::pin_project;
34+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
35+
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
36+
use tokio::time::{Instant, Sleep};
37+
use tower::Service;
38+
39+
const HTTP1_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(5);
40+
const HTTP2_KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(5);
41+
const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(5);
42+
43+
#[must_use]
44+
pub fn from_tcp_with_timeouts(socket: TcpListener) -> Server {
45+
add_timeouts(axum_server::from_tcp(socket))
46+
}
47+
48+
#[must_use]
49+
pub fn from_tcp_rustls_with_timeouts(socket: TcpListener, tls: RustlsConfig) -> Server<RustlsAcceptor> {
50+
add_timeouts(axum_server::from_tcp_rustls(socket, tls))
51+
}
52+
53+
fn add_timeouts<A>(mut server: Server<A>) -> Server<A> {
54+
server.http_builder().http1().timer(TokioTimer::new());
55+
server.http_builder().http2().timer(TokioTimer::new());
56+
57+
server.http_builder().http1().header_read_timeout(HTTP1_HEADER_READ_TIMEOUT);
58+
server
59+
.http_builder()
60+
.http2()
61+
.keep_alive_timeout(HTTP2_KEEP_ALIVE_TIMEOUT)
62+
.keep_alive_interval(HTTP2_KEEP_ALIVE_INTERVAL);
63+
64+
server
65+
}
66+
67+
#[derive(Clone)]
68+
pub struct TimeoutAcceptor;
69+
70+
impl<I, S> Accept<I, S> for TimeoutAcceptor {
71+
type Stream = TimeoutStream<I>;
72+
type Service = TimeoutService<S>;
73+
type Future = Ready<std::io::Result<(Self::Stream, Self::Service)>>;
74+
75+
fn accept(&self, stream: I, service: S) -> Self::Future {
76+
let (tx, rx) = mpsc::unbounded_channel();
77+
78+
let stream = TimeoutStream::new(stream, HTTP1_HEADER_READ_TIMEOUT, rx);
79+
let service = TimeoutService::new(service, tx);
80+
81+
std::future::ready(Ok((stream, service)))
82+
}
83+
}
84+
85+
#[derive(Clone)]
86+
pub struct TimeoutService<S> {
87+
inner: S,
88+
sender: UnboundedSender<TimerSignal>,
89+
}
90+
91+
impl<S> TimeoutService<S> {
92+
fn new(inner: S, sender: UnboundedSender<TimerSignal>) -> Self {
93+
Self { inner, sender }
94+
}
95+
}
96+
97+
impl<S, B, Request> Service<Request> for TimeoutService<S>
98+
where
99+
S: Service<Request, Response = Response<B>>,
100+
{
101+
type Response = Response<TimeoutBody<B>>;
102+
type Error = S::Error;
103+
type Future = TimeoutServiceFuture<S::Future>;
104+
105+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106+
self.inner.poll_ready(cx)
107+
}
108+
109+
fn call(&mut self, req: Request) -> Self::Future {
110+
// send timer wait signal
111+
let _ = self.sender.send(TimerSignal::Wait);
112+
113+
TimeoutServiceFuture::new(self.inner.call(req), self.sender.clone())
114+
}
115+
}
116+
117+
pin_project! {
118+
pub struct TimeoutServiceFuture<F> {
119+
#[pin]
120+
inner: F,
121+
sender: Option<UnboundedSender<TimerSignal>>,
122+
}
123+
}
124+
125+
impl<F> TimeoutServiceFuture<F> {
126+
fn new(inner: F, sender: UnboundedSender<TimerSignal>) -> Self {
127+
Self {
128+
inner,
129+
sender: Some(sender),
130+
}
131+
}
132+
}
133+
134+
impl<F, B, E> Future for TimeoutServiceFuture<F>
135+
where
136+
F: Future<Output = Result<Response<B>, E>>,
137+
{
138+
type Output = Result<Response<TimeoutBody<B>>, E>;
139+
140+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141+
let this = self.project();
142+
this.inner.poll(cx).map(|result| {
143+
result.map(|response| {
144+
response.map(|body| TimeoutBody::new(body, this.sender.take().expect("future polled after ready")))
145+
})
146+
})
147+
}
148+
}
149+
150+
enum TimerSignal {
151+
Wait,
152+
Reset,
153+
}
154+
155+
pin_project! {
156+
pub struct TimeoutBody<B> {
157+
#[pin]
158+
inner: B,
159+
sender: UnboundedSender<TimerSignal>,
160+
}
161+
}
162+
163+
impl<B> TimeoutBody<B> {
164+
fn new(inner: B, sender: UnboundedSender<TimerSignal>) -> Self {
165+
Self { inner, sender }
166+
}
167+
}
168+
169+
impl<B: Body> Body for TimeoutBody<B> {
170+
type Data = B::Data;
171+
type Error = B::Error;
172+
173+
fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
174+
let this = self.project();
175+
let option = ready!(this.inner.poll_frame(cx));
176+
177+
if option.is_none() {
178+
let _ = this.sender.send(TimerSignal::Reset);
179+
}
180+
181+
Poll::Ready(option)
182+
}
183+
184+
fn is_end_stream(&self) -> bool {
185+
let is_end_stream = self.inner.is_end_stream();
186+
187+
if is_end_stream {
188+
let _ = self.sender.send(TimerSignal::Reset);
189+
}
190+
191+
is_end_stream
192+
}
193+
194+
fn size_hint(&self) -> http_body::SizeHint {
195+
self.inner.size_hint()
196+
}
197+
}
198+
199+
pub struct TimeoutStream<IO> {
200+
inner: IO,
201+
// hyper requires unpin
202+
sleep: Pin<Box<Sleep>>,
203+
duration: Duration,
204+
waiting: bool,
205+
receiver: UnboundedReceiver<TimerSignal>,
206+
finished: bool,
207+
}
208+
209+
impl<IO> TimeoutStream<IO> {
210+
fn new(inner: IO, duration: Duration, receiver: UnboundedReceiver<TimerSignal>) -> Self {
211+
Self {
212+
inner,
213+
sleep: Box::pin(tokio::time::sleep(duration)),
214+
duration,
215+
waiting: false,
216+
receiver,
217+
finished: false,
218+
}
219+
}
220+
}
221+
222+
impl<IO: AsyncRead + Unpin> AsyncRead for TimeoutStream<IO> {
223+
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
224+
if !self.finished {
225+
match Pin::new(&mut self.receiver).poll_recv(cx) {
226+
// reset the timer
227+
Poll::Ready(Some(TimerSignal::Reset)) => {
228+
self.waiting = false;
229+
230+
let deadline = Instant::now() + self.duration;
231+
self.sleep.as_mut().reset(deadline);
232+
}
233+
// enter waiting mode (for response body last chunk)
234+
Poll::Ready(Some(TimerSignal::Wait)) => self.waiting = true,
235+
Poll::Ready(None) => self.finished = true,
236+
Poll::Pending => (),
237+
}
238+
}
239+
240+
if !self.waiting {
241+
// return error if timer is elapsed
242+
if let Poll::Ready(()) = self.sleep.as_mut().poll(cx) {
243+
return Poll::Ready(Err(std::io::Error::new(ErrorKind::TimedOut, "request header read timed out")));
244+
}
245+
}
246+
247+
Pin::new(&mut self.inner).poll_read(cx, buf)
248+
}
249+
}
250+
251+
impl<IO: AsyncWrite + Unpin> AsyncWrite for TimeoutStream<IO> {
252+
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
253+
Pin::new(&mut self.inner).poll_write(cx, buf)
254+
}
255+
256+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
257+
Pin::new(&mut self.inner).poll_flush(cx)
258+
}
259+
260+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
261+
Pin::new(&mut self.inner).poll_shutdown(cx)
262+
}
263+
264+
fn poll_write_vectored(
265+
mut self: Pin<&mut Self>,
266+
cx: &mut Context<'_>,
267+
bufs: &[std::io::IoSlice<'_>],
268+
) -> Poll<Result<usize, std::io::Error>> {
269+
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
270+
}
271+
272+
fn is_write_vectored(&self) -> bool {
273+
self.inner.is_write_vectored()
274+
}
275+
}

src/web/api/server/mod.rs

+33-20
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
pub mod custom_axum;
2+
pub mod signals;
13
pub mod v1;
24

35
use std::net::SocketAddr;
46
use std::sync::Arc;
57

8+
use axum_server::Handle;
69
use log::info;
7-
use tokio::net::TcpListener;
8-
use tokio::sync::oneshot::{self, Sender};
10+
use tokio::sync::oneshot::{Receiver, Sender};
911
use v1::routes::router;
1012

11-
use super::{Running, ServerStartedMessage};
13+
use self::signals::{Halted, Started};
14+
use super::Running;
1215
use crate::common::AppData;
16+
use crate::web::api::server::custom_axum::TimeoutAcceptor;
17+
use crate::web::api::server::signals::graceful_shutdown;
1318

1419
/// Starts the API server.
1520
///
@@ -21,13 +26,14 @@ pub async fn start(app_data: Arc<AppData>, net_ip: &str, net_port: u16) -> Runni
2126
.parse()
2227
.expect("API server socket address to be valid.");
2328

24-
let (tx, rx) = oneshot::channel::<ServerStartedMessage>();
29+
let (tx_start, rx) = tokio::sync::oneshot::channel::<Started>();
30+
let (_tx_halt, rx_halt) = tokio::sync::oneshot::channel::<Halted>();
2531

2632
// Run the API server
2733
let join_handle = tokio::spawn(async move {
2834
info!("Starting API server with net config: {} ...", config_socket_addr);
2935

30-
start_server(config_socket_addr, app_data.clone(), tx).await;
36+
start_server(config_socket_addr, app_data.clone(), tx_start, rx_halt).await;
3137

3238
info!("API server stopped");
3339

@@ -46,27 +52,34 @@ pub async fn start(app_data: Arc<AppData>, net_ip: &str, net_port: u16) -> Runni
4652
}
4753
}
4854

49-
async fn start_server(config_socket_addr: SocketAddr, app_data: Arc<AppData>, tx: Sender<ServerStartedMessage>) {
50-
let tcp_listener = TcpListener::bind(config_socket_addr)
51-
.await
52-
.expect("tcp listener to bind to a socket address");
55+
async fn start_server(
56+
config_socket_addr: SocketAddr,
57+
app_data: Arc<AppData>,
58+
tx_start: Sender<Started>,
59+
rx_halt: Receiver<Halted>,
60+
) {
61+
let router = router(app_data);
62+
let socket = std::net::TcpListener::bind(config_socket_addr).expect("Could not bind tcp_listener to address.");
63+
let address = socket.local_addr().expect("Could not get local_addr from tcp_listener.");
5364

54-
let bound_addr = tcp_listener
55-
.local_addr()
56-
.expect("tcp listener to be bound to a socket address.");
65+
let handle = Handle::new();
5766

58-
info!("API server listening on http://{}", bound_addr); // # DevSkim: ignore DS137138
67+
tokio::task::spawn(graceful_shutdown(
68+
handle.clone(),
69+
rx_halt,
70+
format!("Shutting down API server on socket address: {address}"),
71+
));
5972

60-
let app = router(app_data);
73+
info!("API server listening on http://{}", address); // # DevSkim: ignore DS137138
6174

62-
tx.send(ServerStartedMessage { socket_addr: bound_addr })
75+
tx_start
76+
.send(Started { socket_addr: address })
6377
.expect("the API server should not be dropped");
6478

65-
axum::serve(tcp_listener, app.into_make_service_with_connect_info::<SocketAddr>())
66-
.with_graceful_shutdown(async move {
67-
tokio::signal::ctrl_c().await.expect("Failed to listen to shutdown signal.");
68-
info!("Stopping API server on http://{} ...", bound_addr); // # DevSkim: ignore DS137138
69-
})
79+
custom_axum::from_tcp_with_timeouts(socket)
80+
.handle(handle)
81+
.acceptor(TimeoutAcceptor)
82+
.serve(router.into_make_service_with_connect_info::<std::net::SocketAddr>())
7083
.await
7184
.expect("API server should be running");
7285
}

0 commit comments

Comments
 (0)