From 541d7bbbc836bc50c2cbb96b1c4c26e9ef0b7cee Mon Sep 17 00:00:00 2001 From: Icelk Date: Wed, 19 Jul 2023 17:41:45 +0200 Subject: [PATCH] HTTP/3 support. Uses the `h3` crate. Quinn for the QUIC stack. - Adds alt-svc headers to responses - Removed Async{Read, Write} for ResponseBodyPipe, since they are inefficient for HTTP/2&/3 implementations. - Redesigned reverse proxy & WebSocket. - send_response, etc, all take ownership. This is elegant. - WebSocket fails for HTTP/2. Before it silently errored. TODO: - WebSocket / WebTransport? support for HTTP/2 & HTTP/3 - io_uring support for HTTP/3 --- Cargo.toml | 11 +- extensions/src/push.rs | 12 +- extensions/src/reverse-proxy.rs | 101 +++++---- roadmap.md | 27 +-- src/application.rs | 372 ++++++++++++++++++-------------- src/extensions.rs | 41 ++-- src/host.rs | 12 +- src/lib.rs | 336 +++++++++++++++++++++-------- src/prelude.rs | 8 +- src/shutdown.rs | 119 +++++++--- src/websocket.rs | 79 ++++++- 11 files changed, 733 insertions(+), 385 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 41ee4d5..8ea3f75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ bytes = "1" compact_str = "0.7.0" log = "0.4" time = { version = "0.3", features = ["parsing", "formatting", "macros"] } +socket2 = { version = "0.5.3", optional = true, features = ["all"] } h2 = { version = "0.3.17", default-features = false, optional = true } http = "0.2" @@ -74,6 +75,11 @@ tokio-tungstenite = { version = "0.19", optional = true, default-features = fals sha-1 = { version = "0.10", optional = true } futures-util = { version = "0.3", optional = true, default-features = false, features = ["sink"] } +# HTTP/3 +h3 = { version = "0.0.2", optional = true } +h3-quinn = { version = "0.0.3", optional = true } +quinn = { version = "0.10.1", default-features = false, features = ["tls-rustls", "log", "runtime-tokio"], optional = true } + [target.'cfg(unix)'.dependencies] libc = { version = "0.2", default-features = false } @@ -92,9 +98,10 @@ br = ["brotli"] gzip = ["flate2"] # HTTP standards -all-http = ["https", "http2"] +all-http = ["https", "http2", "http3"] https = ["rustls", "rustls-pemfile", "webpki", "async-networking"] http2 = ["h2", "https"] +http3 = ["h3", "h3-quinn", "quinn", "https"] # Graceful shutdown; shutdown.rs graceful-shutdown = ["handover"] @@ -111,7 +118,7 @@ nonce = ["rand", "base64", "memchr"] websocket = ["tokio-tungstenite", "sha-1", "base64", "futures-util"] # Use tokio's async networking instead of the blocking variant. -async-networking = ["tokio/net"] +async-networking = ["tokio/net", "socket2"] uring = ["tokio-uring", "kvarn_signal/uring", "async-networking"] diff --git a/extensions/src/push.rs b/extensions/src/push.rs index 06fbf2d..87289ab 100644 --- a/extensions/src/push.rs +++ b/extensions/src/push.rs @@ -50,7 +50,7 @@ pub fn mount(extensions: &mut Extensions, manager: SmartPush) -> &mut Extensions pub fn always<'a>( request: &'a FatRequest, host: &'a Host, - response_pipe: &'a mut application::ResponsePipe, + response_pipe: &'a mut application::ResponseBodyPipe, bytes: Bytes, addr: SocketAddr, ) -> RetFut<'a, ()> { @@ -104,7 +104,7 @@ impl Default for SmartPush { async fn push<'a>( request: &'a FatRequest, host: &'a Host, - response_pipe: &'a mut application::ResponsePipe, + response_pipe: &'a mut application::ResponseBodyPipe, bytes: Bytes, addr: SocketAddr, manager: Option<&'a Mutex>, @@ -113,9 +113,9 @@ async fn push<'a>( // let request = unsafe { request.get_inner() }; // let response_pipe = unsafe { response_pipe.get_inner() }; - // If it is not HTTP/1 + // If it is not HTTP/2 #[allow(irrefutable_let_patterns)] - if let ResponsePipe::Http1(_) = &response_pipe { + if !matches!(response_pipe, ResponseBodyPipe::Http2(_, _)) { return; } @@ -211,7 +211,7 @@ async fn push<'a>( let empty_request = utils::empty_clone_request(&push_request); - let mut response_pipe = match response_pipe.push_request(empty_request) { + let response_pipe = match response_pipe.push_request(empty_request) { Ok(pipe) => pipe, Err(_) => return, }; @@ -221,7 +221,7 @@ async fn push<'a>( let response = kvarn::handle_cache(&mut push_request, addr, host).await; - if let Err(err) = kvarn::SendKind::Push(&mut response_pipe) + if let Err(err) = kvarn::SendKind::Push(response_pipe) .send(response, request, host, addr) .await { diff --git a/extensions/src/reverse-proxy.rs b/extensions/src/reverse-proxy.rs index 92b234d..98b8b0f 100644 --- a/extensions/src/reverse-proxy.rs +++ b/extensions/src/reverse-proxy.rs @@ -17,22 +17,6 @@ pub mod async_bits { } }; } - macro_rules! ret_ready_err { - ($poll: expr) => { - match $poll { - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(r) => Poll::Ready(r), - _ => $poll, - } - }; - ($poll: expr, $map: expr) => { - match $poll { - Poll::Ready(Err(e)) => return Poll::Ready(Err($map(e))), - Poll::Ready(r) => Poll::Ready(r), - _ => Poll::Pending, - } - }; - } #[derive(Debug)] pub struct CopyBuffer { @@ -251,45 +235,72 @@ impl OpenBackError { } } } -pub struct ByteProxy<'a, F: AsyncRead + AsyncWrite + Unpin, B: AsyncRead + AsyncWrite + Unpin> { - front: &'a mut F, +pub struct ByteProxy<'a, B: AsyncRead + AsyncWrite + Unpin> { + front: &'a mut ResponseBodyPipe, back: &'a mut B, // ToDo: Optimize to one buffer! - front_buf: CopyBuffer, - back_buf: CopyBuffer, + buf: Vec, } -impl<'a, F: AsyncRead + AsyncWrite + Unpin, B: AsyncRead + AsyncWrite + Unpin> ByteProxy<'a, F, B> { - pub fn new(front: &'a mut F, back: &'a mut B) -> Self { +impl<'a, B: AsyncRead + AsyncWrite + Unpin> ByteProxy<'a, B> { + pub fn new(front: &'a mut ResponseBodyPipe, back: &'a mut B) -> Self { Self { front, back, - front_buf: CopyBuffer::new(), - back_buf: CopyBuffer::new(), + buf: Vec::with_capacity(16 * 1024), } } - pub fn poll_channel(&mut self, cx: &mut Context) -> Poll> { - macro_rules! copy_from_to { - ($reader: expr, $error: expr, $buf: expr, $writer: expr) => { - if let Poll::Ready(Ok(pipe_closed)) = ret_ready_err!( - $buf.poll_copy(cx, Pin::new($reader), Pin::new($writer)), - $error - ) { - if pipe_closed { - return Poll::Ready(Err(OpenBackError::Closed)); - } else { - return Poll::Ready(Ok(())); + pub async fn channel(&mut self) -> Result<(), OpenBackError> { + let mut front_done = false; + let mut back_done = false; + loop { + if !front_done { + if let ResponseBodyPipe::Http1(h1) = self.front { + { + unsafe { self.buf.set_len(self.buf.capacity()) }; + let read = h1 + .lock() + .await + .read(&mut self.buf) + .await + .map_err(OpenBackError::Front)?; + if read == 0 { + front_done = true; + } + unsafe { self.buf.set_len(read) }; } - }; - }; - } - - copy_from_to!(self.back, OpenBackError::Back, self.front_buf, self.front); - copy_from_to!(self.front, OpenBackError::Front, self.back_buf, self.back); + self.back + .write_all(&self.buf) + .await + .map_err(OpenBackError::Back)?; + } else { + front_done = true; + } + } + if !back_done { + { + unsafe { self.buf.set_len(self.buf.capacity()) }; + let read = self + .back + .read(&mut self.buf) + .await + .map_err(OpenBackError::Back)?; + if read == 0 { + back_done = true; + } + unsafe { self.buf.set_len(read) }; + } + self.front + .send(Bytes::copy_from_slice(&self.buf)) + .await + .map_err(io::Error::from) + .map_err(OpenBackError::Front)?; + } - Poll::Pending - } - pub async fn channel(&mut self) -> Result<(), OpenBackError> { - futures_util::future::poll_fn(|cx| self.poll_channel(cx)).await + if front_done && back_done { + break; + } + } + Ok(()) } } diff --git a/roadmap.md b/roadmap.md index e9d9230..53433c3 100644 --- a/roadmap.md +++ b/roadmap.md @@ -7,20 +7,26 @@ Info on changes in older versions are available at the [changelog](CHANGELOG.md) > The work will be taking place in branches, named after the target release. The order of these feature releases are not set in stone; > the features of 0.7.0 might come out as version 0.6.0 -# v0.6.0 HTTP/3 +# v0.6.0 edgebleed This is where Kvarn turns into a cutting-edge web server. > Kvarn already has a good flexible design, so adding this is largely making > a glue-crate to make HTTP/3 accessible like HTTP/2 is in the `h2` crate. +## v0.8.0 io_uring + +Use Linux's new `io_uring` interface for handling networking and IO on Linux. +This should improve performance and power efficiency. This is merged into v0.6.0. + ## To do _Well..._ -- [ ] HTTP/3 crate -- [ ] HTTP/3 support in Kvarn -- [ ] cfg to disable new feature +- [x] HTTP/3 support in Kvarn +- [x] cfg to disable new feature +- [x] io_uring support +- [ ] io_uring support for HTTP/3 # v0.7.0 DynLan @@ -43,16 +49,3 @@ Another challenge is isolating requests while using one VM. - [ ] cfg - [ ] PHP bindings - [ ] PHP crate - -# v0.8.0 io_uring - -Use Linux's new `io_uring` interface for handling networking and IO on Linux. -This should improve performance and power efficiency. - -## To do - -- [ ] Wait for [`tokio-uring`](https://docs.rs/tokio-uring) to add multithreading support -- [ ] Or support an entirely different runtime (e.g. [`monoio`](https://github.com/bytedance/monoio) - (it shouldn't be an issue that it's developed by ByteDance? Are be being tracked?)) - - [ ] Investigate compatibility issues with ecosystem. Actual implementation should be fine - (the `net` feature in `tokio` is already optional) diff --git a/src/application.rs b/src/application.rs index 141e662..764027e 100644 --- a/src/application.rs +++ b/src/application.rs @@ -13,6 +13,8 @@ pub use response::Http1Body; #[cfg(all(feature = "uring", not(feature = "async-networking")))] compile_error!("You must enable the 'async-networking' feature to use uring."); +#[cfg(all(feature = "http3", not(feature = "async-networking")))] +compile_error!("You must enable the 'async-networking' feature to use HTTP/3."); #[cfg(feature = "uring")] pub use uring_tokio_compat::TcpStreamAsyncWrapper; @@ -30,13 +32,17 @@ pub enum Error { /// [`h2`] emitted an error #[cfg(feature = "http2")] H2(h2::Error), + /// [`h3`] emitted an error + #[cfg(feature = "http3")] + H3(h3::Error), /// The HTTP version assumed by the client is not supported. /// Invalid ALPN config is a candidate. VersionNotSupported, - /// You tried to push a response on a HTTP/1 connection. + /// You tried to push a response on a HTTP/1 (or HTTP/3, for now) connection. /// /// *Use HTTP/2 instead, or check if the [`ResponsePipe`] is HTTP/1*. - PushOnHttp1, + /// Will also fail if you try to push on a pipe returned from a previous push. + UnsupportedPush, /// Client closed connection before the response could be sent. ClientRefusedResponse, } @@ -65,6 +71,13 @@ impl From for Error { Self::H2(err) } } +#[cfg(feature = "http3")] +impl From for Error { + #[inline] + fn from(err: h3::Error) -> Self { + Self::H3(err) + } +} impl From for io::Error { fn from(err: Error) -> io::Error { match err { @@ -72,12 +85,14 @@ impl From for io::Error { Error::Io(io) => io, #[cfg(feature = "http2")] Error::H2(h2) => io::Error::new(io::ErrorKind::InvalidData, h2), + #[cfg(feature = "http3")] + Error::H3(h3) => io::Error::new(io::ErrorKind::InvalidData, h3), Error::VersionNotSupported => io::Error::new( io::ErrorKind::InvalidData, "http version unsupported. Invalid ALPN config.", ), - Error::PushOnHttp1 => io::Error::new( + Error::UnsupportedPush => io::Error::new( io::ErrorKind::InvalidInput, "can not push requests on http/1", ), @@ -92,12 +107,11 @@ impl From for io::Error { /// /// See [`HttpConnection::new`] on how to make one and /// [`HttpConnection::accept`] on getting a [`FatRequest`]. -#[derive(Debug)] #[must_use] pub enum HttpConnection { - /// A HTTP/1 connection + /// An HTTP/1 connection Http1(Arc>), - /// A HTTP/2 connection + /// An HTTP/2 connection /// /// This is boxed because a [`h2::server::Connection`] takes up /// over 1000 bytes of memory, and an [`Arc`] 8 bytes. @@ -107,6 +121,23 @@ pub enum HttpConnection { /// We'll see how we move forward once HTTP/3 support lands. #[cfg(feature = "http2")] Http2(Box>), + #[cfg(feature = "http3")] + /// An HTTP/3 conenction. + Http3(h3::server::Connection), +} +impl Debug for HttpConnection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Http1(arg0) => f.debug_tuple("Http1").field(arg0).finish(), + #[cfg(feature = "http2")] + Self::Http2(arg0) => f.debug_tuple("Http2").field(arg0).finish(), + #[cfg(feature = "http3")] + Self::Http3(_) => f + .debug_tuple("Http3") + .field(&"[internal h3 connection]".as_clean()) + .finish(), + } + } } /// The data for [`Body::Bytes`]. @@ -117,9 +148,18 @@ pub struct ByteBody { read: usize, } impl ByteBody { - /// Get a reference to the bytes of this body. - pub fn inner(&self) -> &Bytes { - &self.content + /// Read the rest of the bytes of this body + pub fn read_rest(&mut self) -> Bytes { + let b = self.content.slice(self.read..); + self.read = self.content.len(); + b + } + /// Read `n` bytes of this body + pub fn read_n(&mut self, n: usize) -> Bytes { + let n = n.min(self.content.len() - self.read); + let b = self.content.slice(self.read..(self.read + n)); + self.read += n; + b } } impl From for ByteBody { @@ -136,7 +176,6 @@ impl From for ByteBody { /// The inner variables are streams. To get the bytes, use [`Body::read_to_bytes()`] when needed. /// /// Also see [`FatRequest`]. -#[derive(Debug)] pub enum Body { /// A body of [`Bytes`]. /// @@ -151,35 +190,94 @@ pub enum Body { /// [`Body::read_to_bytes`] leverages this and just /// continues writing to the buffer. Http1(response::Http1Body), - /// A HTTP/2 body provided by [`h2`]. + /// An HTTP/2 body provided by [`h2`]. #[cfg(feature = "http2")] Http2(h2::RecvStream), + /// An HTTP/3 body provided by [`h3`]. + #[cfg(feature = "http3")] + Http3(h3::server::RequestStream), +} +impl Debug for Body { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Bytes(arg0) => f.debug_tuple("Bytes").field(arg0).finish(), + Self::Http1(arg0) => f.debug_tuple("Http1").field(arg0).finish(), + #[cfg(feature = "http2")] + Self::Http2(arg0) => f.debug_tuple("Http2").field(arg0).finish(), + #[cfg(feature = "http3")] + Self::Http3(_) => f + .debug_tuple("Http3") + .field(&"[internal h3 connection]".as_clean()) + .finish(), + } + } } /// A pipe to send a [`Response`] through. /// /// You may also push requests if the pipe is [`ResponsePipe::Http2`] /// by calling [`ResponsePipe::push_request`]. -#[derive(Debug)] #[must_use] pub enum ResponsePipe { - /// A HTTP/1 stream to send a response. + /// An HTTP/1 stream to send a response. Http1(Arc>), - /// A HTTP/2 response pipe. + /// An HTTP/2 response pipe. #[cfg(feature = "http2")] Http2(h2::server::SendResponse), + /// An HTTP/3 response pipe. + #[cfg(feature = "http3")] + Http3(h3::server::RequestStream, Bytes>), +} +impl Debug for ResponsePipe { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Http1(arg0) => f.debug_tuple("Http1").field(arg0).finish(), + #[cfg(feature = "http2")] + Self::Http2(arg0) => f.debug_tuple("Http2").field(arg0).finish(), + #[cfg(feature = "http3")] + Self::Http3(_) => f + .debug_tuple("Http3") + .field(&"[internal h3 connection]".as_clean()) + .finish(), + } + } +} +/// Abstraction layer over different kinds of HTTP/2 response senders. +#[derive(Debug)] +#[cfg(feature = "http2")] +pub enum H2SendResponse { + /// The initial response + Initial(h2::server::SendResponse), + /// Server-pushed responses + Pushed(h2::server::SendPushedResponse), } /// A pipe to send a body after the [`Response`] is sent by /// [`ResponsePipe::send_response`]. /// /// The [`AsyncWriteExt::shutdown`] does nothing, and will immediately return with Ok(()) -#[derive(Debug)] pub enum ResponseBodyPipe { /// HTTP/1 pipe Http1(Arc>), /// HTTP/2 pipe #[cfg(feature = "http2")] - Http2(h2::SendStream), + Http2(h2::SendStream, H2SendResponse), + /// HTTP/3 pipe + #[cfg(feature = "http3")] + Http3(h3::server::RequestStream, Bytes>), +} +impl Debug for ResponseBodyPipe { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Http1(arg0) => f.debug_tuple("Http1").field(arg0).finish(), + #[cfg(feature = "http2")] + Self::Http2(arg0, arg1) => f.debug_tuple("Http2").field(arg0).field(arg1).finish(), + #[cfg(feature = "http3")] + Self::Http3(_) => f + .debug_tuple("Http3") + .field(&"[internal h3 connection]".as_clean()) + .finish(), + } + } } /// A [`ResponsePipe`]-like for a pushed request-response pair. /// @@ -260,6 +358,17 @@ impl HttpConnection { }, None => Err(utils::parse::Error::UnexpectedEnd.into()), }, + #[cfg(feature = "http3")] + Self::Http3(c) => match c.accept().await { + Ok(opt) => match opt { + Some((req, stream)) => { + let (write, read) = stream.split(); + Ok((req.map(|()| Body::Http3(read)), ResponsePipe::Http3(write))) + } + None => Err(utils::parse::Error::UnexpectedEnd.into()), + }, + Err(err) => Err(err.into()), + }, } } /// Ask this connection to shutdown. @@ -269,16 +378,15 @@ impl HttpConnection { drop(h.lock().await.shutdown().await); } #[cfg(feature = "http2")] - Self::Http2(_h) => {} + Self::Http2(mut h) => h.graceful_shutdown(), + #[cfg(feature = "http3")] + Self::Http3(mut h) => drop(h.shutdown(1024)), } } } mod request { - use super::{ - io, response, utils, Arc, AsyncRead, Body, Bytes, Context, Encryption, Error, Mutex, Pin, - Poll, ReadBuf, Request, - }; + use super::{io, response, utils, Arc, Body, Bytes, Encryption, Error, Mutex, Request}; #[inline] pub(crate) async fn parse_http_1( @@ -510,7 +618,7 @@ mod request { #[inline] pub async fn read_to_bytes(&mut self, max_len: usize) -> io::Result { match self { - Self::Bytes(bytes) => Ok(bytes.inner().clone()), + Self::Bytes(bytes) => Ok(bytes.read_rest()), Self::Http1(h1) => h1.read_to_bytes(max_len).await, #[cfg(feature = "http2")] Self::Http2(h2) => { @@ -529,49 +637,30 @@ mod request { .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; bytes.extend_from_slice(&data[..(data.len().min(left))]); + let left = max_len.saturating_sub(bytes.len()); + if left == 0 { + break; + } } Ok(bytes.freeze()) } - } - } - } + #[cfg(feature = "http3")] + Self::Http3(h3) => { + use bytes::BufMut; + let mut bytes = bytes::BytesMut::new(); + while let Some(data) = h3.recv_data().await.map_err(Error::H3)? { + let left = max_len.saturating_sub(bytes.len()); + if left == 0 { + break; + } - impl AsyncRead for Body { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.get_mut() { - Self::Http1(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) }, - #[cfg(feature = "http2")] - Self::Http2(tls) => { - let data = match tls.poll_data(cx) { - Poll::Ready(data) => data, - Poll::Pending => return Poll::Pending, - }; - match data { - Some(d) => match d { - Ok(data) => buf.put_slice(&data), - Err(err) => { - let err = io::Error::new(io::ErrorKind::InvalidData, err); - return Poll::Ready(Err(err)); - } - }, - None => return Poll::Ready(Ok(())), - } - Poll::Ready(Ok(())) - } - Self::Bytes(byte_body) => { - let rest = byte_body.inner().get(byte_body.read..).unwrap_or(&[]); - if rest.is_empty() { - return Poll::Ready(Ok(())); + bytes.put(data); + let left = max_len.saturating_sub(bytes.len()); + if left == 0 { + break; + } } - let len = std::cmp::min(buf.remaining(), rest.len()); - buf.put_slice(&rest[..len]); - byte_body.read += len; - cx.waker().wake_by_ref(); - Poll::Pending + Ok(bytes.freeze()) } } } @@ -713,61 +802,47 @@ mod response { /// [`h2::server::SendResponse::send_response()`] for more info. #[inline] pub async fn send_response( - &mut self, + self, mut response: Response<()>, #[allow(unused_variables)] end_of_stream: bool, ) -> Result { match self { Self::Http1(s) => { - let mut writer = s.lock().await; - match response - .headers() - .get("connection") - .map(HeaderValue::to_str) - .and_then(Result::ok) { - Some("close") | None => { - response - .headers_mut() - .insert("connection", HeaderValue::from_static("keep-alive")); + let mut writer = s.lock().await; + match response + .headers() + .get("connection") + .map(HeaderValue::to_str) + .and_then(Result::ok) + { + Some("close") | None => { + response + .headers_mut() + .insert("connection", HeaderValue::from_static("keep-alive")); + } + _ => {} } - _ => {} + let mut writer = tokio::io::BufWriter::with_capacity(512, &mut *writer); + async_bits::write::response(&response, b"", &mut writer).await?; + writer.flush().await?; + writer.into_inner(); } - let mut writer = tokio::io::BufWriter::with_capacity(512, &mut *writer); - async_bits::write::response(&response, b"", &mut writer).await?; - writer.flush().await?; - writer.into_inner(); - Ok(ResponseBodyPipe::Http1(Arc::clone(s))) + Ok(ResponseBodyPipe::Http1(s)) } #[cfg(feature = "http2")] - Self::Http2(s) => match s.send_response(response, end_of_stream) { + Self::Http2(mut s) => match s.send_response(response, end_of_stream) { Err(ref err) if err.get_io().is_none() && err.reason().is_none() => { Err(Error::ClientRefusedResponse) } Err(err) => Err(err.into()), - Ok(pipe) => Ok(ResponseBodyPipe::Http2(pipe)), + Ok(pipe) => Ok(ResponseBodyPipe::Http2(pipe, H2SendResponse::Initial(s))), }, - } - } - /// Pushes `request` to client. - /// - /// # Errors - /// - /// If you try to push if `self` is [`ResponsePipe::Http1`], an [`Error::PushOnHttp1`] is returned. - /// Returns errors from [`h2::server::SendResponse::push_request()`]. - #[inline] - #[allow(clippy::needless_pass_by_value)] - pub fn push_request( - &mut self, - #[allow(unused_variables)] request: Request<()>, - ) -> Result { - match self { - Self::Http1(_) => Err(Error::PushOnHttp1), - #[cfg(feature = "http2")] - Self::Http2(h2) => match h2.push_request(request) { - Ok(pipe) => Ok(PushedResponsePipe::Http2(pipe)), + #[cfg(feature = "http3")] + Self::Http3(mut s) => match s.send_response(response).await { Err(err) => Err(err.into()), + Ok(()) => Ok(ResponseBodyPipe::Http3(s)), }, } } @@ -784,6 +859,8 @@ mod response { }, #[cfg(feature = "http2")] Self::Http2(_) => *response.version_mut() = Version::HTTP_2, + #[cfg(feature = "http3")] + Self::Http3(_) => *response.version_mut() = Version::HTTP_3, } } } @@ -798,22 +875,23 @@ mod response { #[inline] #[allow(clippy::needless_pass_by_value)] pub fn send_response( - &mut self, + self, response: Response<()>, end_of_stream: bool, ) -> Result { match self { #[cfg(feature = "http2")] - Self::Http2(s) => { + Self::Http2(mut s) => { let mut response = response; *response.version_mut() = Version::HTTP_2; match s.send_response(response, end_of_stream) { Err(err) => Err(err.into()), - Ok(pipe) => Ok(ResponseBodyPipe::Http2(pipe)), + Ok(pipe) => Ok(ResponseBodyPipe::Http2(pipe, H2SendResponse::Pushed(s))), } } - #[cfg(not(any(feature = "http2")))] + #[allow(unreachable_patterns)] + #[cfg(not(feature = "http2"))] _ => unreachable!(), } } @@ -857,10 +935,37 @@ mod response { } } #[cfg(feature = "http2")] - Self::Http2(h2) => h2.send_data(data, end_of_stream)?, + Self::Http2(h2, _) => h2.send_data(data, end_of_stream)?, + #[cfg(feature = "http3")] + Self::Http3(h3) => h3.send_data(data).await?, } Ok(()) } + /// Pushes `request` to client. + /// + /// # Errors + /// + /// If you try to push if `self` is [`ResponsePipe::Http1`], an [`Error::PushOnHttp1`] is returned. + /// Returns errors from [`h2::server::SendResponse::push_request()`]. + #[inline] + #[allow(clippy::needless_pass_by_value)] + pub fn push_request( + &mut self, + #[allow(unused_variables)] request: Request<()>, + ) -> Result { + match self { + Self::Http1(_) => Err(Error::UnsupportedPush), + #[cfg(feature = "http2")] + Self::Http2(_, H2SendResponse::Pushed(_)) => Err(Error::UnsupportedPush), + #[cfg(feature = "http2")] + Self::Http2(_, H2SendResponse::Initial(h2)) => match h2.push_request(request) { + Ok(pipe) => Ok(PushedResponsePipe::Http2(pipe)), + Err(err) => Err(err.into()), + }, + #[cfg(feature = "http3")] + Self::Http3(_) => Err(Error::UnsupportedPush), + } + } /// Closes the pipe. /// /// # Errors @@ -872,64 +977,9 @@ mod response { match self { Self::Http1(h1) => h1.lock().await.flush().await.map_err(Into::into), #[cfg(feature = "http2")] - Self::Http2(h2) => h2.send_data(Bytes::new(), true).map_err(Error::from), - } - } - } - impl AsyncRead for ResponseBodyPipe { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.get_mut() { - Self::Http1(s) => match s.try_lock() { - Err(_) => Poll::Pending, - Ok(mut s) => Pin::new(&mut *s).poll_read(cx, buf), - }, - #[cfg(feature = "http2")] - Self::Http2(_) => Poll::Ready(Ok(())), - } - } - } - impl AsyncWrite for ResponseBodyPipe { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.get_mut() { - Self::Http1(s) => match s.try_lock() { - Err(_) => Poll::Pending, - Ok(mut s) => Pin::new(&mut *s).poll_write(cx, buf), - }, - #[cfg(feature = "http2")] - Self::Http2(s) => Poll::Ready( - s.send_data(Bytes::copy_from_slice(buf), false) - .map_err(|e| { - if e.is_io() { - // This is ok; we just checked it is IO. - e.into_io().unwrap() - } else { - io::Error::new(io::ErrorKind::Other, e.to_string()) - } - }) - .map(|()| buf.len()), - ), - } - } - fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Self::Http1(s) = self.get_mut() { - if let Ok(mut s) = s.try_lock() { - Pin::new(&mut *s).poll_flush(cx) - } else { - Poll::Pending - } - } else { - Poll::Ready(Ok(())) + Self::Http2(h2, _) => h2.send_data(Bytes::new(), true).map_err(Error::from), + #[cfg(feature = "http3")] + Self::Http3(h3) => h3.finish().await.map_err(Error::from), } } } diff --git a/src/extensions.rs b/src/extensions.rs index 12a7e08..13bbfd9 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -214,7 +214,7 @@ pub trait PostCall: KvarnSendSync { &'a self, request: &'a FatRequest, host: &'a Host, - response_pipe: &'a mut ResponsePipe, + response_pipe: &'a mut ResponseBodyPipe, identity_body: Bytes, addr: SocketAddr, ) -> RetFut<'a, ()>; @@ -223,7 +223,7 @@ impl< F: for<'a> Fn( &'a FatRequest, &'a Host, - &'a mut ResponsePipe, + &'a mut ResponseBodyPipe, Bytes, SocketAddr, ) -> RetFut<'a, ()> @@ -234,7 +234,7 @@ impl< &'a self, request: &'a FatRequest, host: &'a Host, - response_pipe: &'a mut ResponsePipe, + response_pipe: &'a mut ResponseBodyPipe, identity_body: Bytes, addr: SocketAddr, ) -> RetFut<'a, ()> { @@ -963,7 +963,7 @@ impl Extensions { &self, request: &FatRequest, bytes: Bytes, - response_pipe: &mut ResponsePipe, + response_pipe: &mut ResponseBodyPipe, addr: SocketAddr, host: &Host, ) { @@ -1206,21 +1206,31 @@ pub fn stream_body() -> Box { #[cfg(not(feature = "uring"))] let len = meta.len() as usize; - #[cfg(feature = "uring")] + #[allow(clippy::uninit_vec)] let fut = response_pipe_fut!(response, _host, move |file: fs::File| { - let mut buf = Vec::with_capacity(1024 * 32); + let mut buf = Vec::with_capacity(1024 * 64); + #[cfg(feature = "uring")] let mut pos = 0; unsafe { buf.set_len(buf.capacity()) }; loop { - let (r, b) = file.read_at(buf, pos).await; - buf = b; + #[cfg(feature = "uring")] + let r = { + let (r, b) = file.read_at(buf, pos).await; + buf = b; + r + }; + #[cfg(not(feature = "uring"))] + let r = file.read(&mut buf).await; match r { Ok(read) => { if read == 0 { break; } - pos += read as u64; - match response.write_all(&buf[..read]).await { + #[cfg(feature = "uring")] + { + pos += read as u64; + } + match response.send(Bytes::copy_from_slice(&buf[..read])).await { Ok(()) => {} Err(_) => { break; @@ -1234,15 +1244,6 @@ pub fn stream_body() -> Box { } } }); - #[cfg(not(feature = "uring"))] - let fut = response_pipe_fut!( - response, - _host, - move |file: fs::File, first_bytes: Vec| { - drop(response.write_all(first_bytes).await); - let _err = tokio::io::copy(file, response).await; - } - ); FatResponse::new(response, comprash::ServerCachePreference::None) .with_future_and_len(fut, len) @@ -1523,7 +1524,7 @@ mod macros { (), |$request: &'a $crate::FatRequest: &$crate::FatRequest: a1, $host: &'a $crate::prelude::Host: &$crate::prelude::Host: a2, - $response_pipe: &'a mut $crate::application::ResponsePipe: &mut $crate::application::ResponsePipe: a3, + $response_pipe: &'a mut $crate::application::ResponseBodyPipe: &mut $crate::application::ResponseBodyPipe: a3, $bytes: $crate::prelude::Bytes: $crate::prelude::Bytes: a4, $addr: $crate::prelude::SocketAddr: $crate::prelude::SocketAddr: a5|, $(($($move:$ty),+))?, diff --git a/src/host.rs b/src/host.rs index 8f966a0..c1f78ef 100644 --- a/src/host.rs +++ b/src/host.rs @@ -982,12 +982,22 @@ impl ResolvesServerCert for Collection { /// > ***Note:** this is often not needed, as the ALPN protocols /// are set in [`host::Collection::make_config()`].* #[must_use] +#[allow(unused_mut)] pub fn alpn() -> Vec> { - let vec = vec![ + let mut vec = vec![ #[cfg(feature = "http2")] b"h2".to_vec(), b"http/1.1".to_vec(), ]; + #[cfg(feature = "http3")] + { + vec.insert(0, b"h3-29".to_vec()); + vec.insert(0, b"h3-30".to_vec()); + vec.insert(0, b"h3-31".to_vec()); + vec.insert(0, b"h3-31".to_vec()); + vec.insert(0, b"h3-32".to_vec()); + vec.insert(0, b"h3".to_vec()); + } vec } diff --git a/src/lib.rs b/src/lib.rs index f019f8f..542aac8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -206,6 +206,9 @@ impl RunConfig { /// # }; /// ``` pub async fn execute(self) -> Arc { + #[cfg(feature = "async-networking")] + use socket2::{Domain, Protocol, Type}; + let RunConfig { ports, #[cfg(feature = "handover")] @@ -244,55 +247,116 @@ impl RunConfig { #[cfg(feature = "async-networking")] for descriptor in &ports { fn create_listener( - create_socket: impl Fn() -> tokio::net::TcpSocket, + create_socket: impl Fn() -> socket2::Socket, + tcp: bool, address: SocketAddr, shutdown_manager: &mut shutdown::Manager, + #[allow(unused_variables)] descriptor: &PortDescriptor, ) -> AcceptManager { let socket = create_socket(); #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))] { - if socket.set_reuseaddr(true).is_err() - || socket.set_reuseport(true).is_err() + if socket.set_reuse_address(true).is_err() + || socket.set_reuse_port(true).is_err() { error!("Failed to set reuse address/port. This is needed for graceful shutdown handover."); } } - socket.bind(address).expect("Failed to bind address"); + socket + .bind(&address.into()) + .expect("Failed to bind address"); + + info!("tcp {tcp}, addr {address:?}"); + + // wrap listener + #[cfg(feature = "http3")] + if !tcp { + return shutdown_manager.add_listener(shutdown::Listener::Udp( + h3_quinn::Endpoint::new( + h3_quinn::quinn::EndpointConfig::default(), + Some(h3_quinn::quinn::ServerConfig::with_crypto( + descriptor.server_config.clone().unwrap(), + )), + socket.into(), + h3_quinn::quinn::default_runtime().unwrap(), + ) + .unwrap(), + )); + } - let listener = socket + socket .listen(1024) .expect("Failed to listen on bound address."); - // wrap listener #[cfg(feature = "uring")] - let listener = - tokio_uring::net::TcpListener::from_std(listener.into_std().unwrap()); + let listener = tokio_uring::net::TcpListener::from_std(socket.into()); + #[cfg(not(feature = "uring"))] + let listener = TcpListener::from_std(socket.into()).unwrap(); - shutdown_manager.add_listener(listener) + shutdown_manager.add_listener(shutdown::Listener::Tcp(listener)) } if matches!(descriptor.version, BindIpVersion::V4 | BindIpVersion::Both) { let listener = create_listener( || { - tokio::net::TcpSocket::new_v4() + socket2::Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)) .expect("Failed to create a new IPv4 socket configuration") }, + true, SocketAddr::new(IpAddr::V4(net::Ipv4Addr::UNSPECIFIED), descriptor.port), &mut shutdown_manager, + descriptor, ); listeners.push((listener, Arc::clone(descriptor))); } if matches!(descriptor.version, BindIpVersion::V6 | BindIpVersion::Both) { let listener = create_listener( || { - tokio::net::TcpSocket::new_v6() + socket2::Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)) .expect("Failed to create a new IPv6 socket configuration") }, + true, SocketAddr::new(IpAddr::V6(net::Ipv6Addr::UNSPECIFIED), descriptor.port), &mut shutdown_manager, + descriptor, ); listeners.push((listener, Arc::clone(descriptor))); } + #[cfg(feature = "http3")] + if descriptor.server_config.is_some() { + if matches!(descriptor.version, BindIpVersion::V4 | BindIpVersion::Both) { + let listener = create_listener( + || { + socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .expect("Failed to create a new IPv4 socket configuration") + }, + false, + SocketAddr::new( + IpAddr::V4(net::Ipv4Addr::UNSPECIFIED), + descriptor.port, + ), + &mut shutdown_manager, + descriptor, + ); + listeners.push((listener, Arc::clone(descriptor))); + } + if matches!(descriptor.version, BindIpVersion::V6 | BindIpVersion::Both) { + let listener = create_listener( + || { + socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) + .expect("Failed to create a new IPv6 socket configuration") + }, + false, + SocketAddr::new( + IpAddr::V6(net::Ipv6Addr::UNSPECIFIED), + descriptor.port, + ), + &mut shutdown_manager, + descriptor, + ); + listeners.push((listener, Arc::clone(descriptor))); + } + } } #[cfg(not(feature = "async-networking"))] for descriptor in &ports { @@ -303,7 +367,7 @@ impl RunConfig { )) .expect("Failed to bind to IPv4"); listeners.push(( - shutdown_manager.add_listener(listener), + shutdown_manager.add_listener(shutdown::Listener::Tcp(listener)), Arc::clone(descriptor), )); } @@ -314,7 +378,7 @@ impl RunConfig { )) .expect("Failed to bind to IPv6"); listeners.push(( - shutdown_manager.add_listener(listener), + shutdown_manager.add_listener(shutdown::Listener::Tcp(listener)), Arc::clone(descriptor), )); } @@ -473,13 +537,22 @@ pub async fn spawn( } } +/// An incoming connection, before it's wrapped with HTTP. +#[derive(Debug)] +pub enum Incoming { + /// Used for HTTP/1 & HTTP/2 + Tcp(TcpStream), + /// Used for HTTP/3 + #[cfg(feature = "http3")] + Udp(h3_quinn::quinn::Connection), +} async fn accept( mut listener: AcceptManager, descriptor: Arc, shutdown_manager: &Arc, first: bool, ) -> Result<(), io::Error> { - let local_addr = listener.get_inner().local_addr().unwrap(); + let local_addr = listener.get_inner().local_addr(); if first { info!( "Started listening on port {} using {}", @@ -489,7 +562,7 @@ async fn accept( } loop { - match listener.accept(shutdown_manager).await { + let (stream, addr) = match listener.accept(shutdown_manager).await { AcceptAction::Shutdown => { if first { info!( @@ -500,48 +573,33 @@ async fn accept( } return Ok(()); } - AcceptAction::Accept(result) => match result { - Ok((socket, addr)) => { - match descriptor.data.limiter().register(addr.ip()) { - LimitAction::Drop => { - drop(socket); - return Ok(()); - } - LimitAction::Send | LimitAction::Passed => {} - } + AcceptAction::AcceptTcp(result) => match result { + Ok((stream, addr)) => (Incoming::Tcp(stream), addr), + Err(err) => { + #[cfg(feature = "graceful-shutdown")] + let connections = format!( + " {} current connections.", + shutdown_manager.get_connecions() + ); + #[cfg(not(feature = "graceful-shutdown"))] + let connections = ""; - let descriptor = Arc::clone(&descriptor); + // An error occurred + error!("Failed to accept() on TCP listener.{connections}"); - #[cfg(feature = "graceful-shutdown")] - let shutdown_manager = Arc::clone(shutdown_manager); - let _task = spawn(async move { - #[cfg(feature = "graceful-shutdown")] - shutdown_manager.add_connection(); - let _result = handle_connection(socket, addr, descriptor, || { - #[cfg(feature = "async-networking")] - { - #[cfg(feature = "graceful-shutdown")] - { - !shutdown_manager.get_shutdown(threading::Ordering::Relaxed) - } - #[cfg(not(feature = "graceful-shutdown"))] - { - true - } - } - #[cfg(not(feature = "async-networking"))] - { - false - } - }) - .await; - #[cfg(feature = "graceful-shutdown")] - shutdown_manager.remove_connection(); - }) - .await; - continue; + return Err(err); + } + }, + #[cfg(feature = "http3")] + AcceptAction::AcceptUdp(result) => match result { + Ok(stream) => { + let addr = stream.remote_address(); + (Incoming::Udp(stream), addr) } Err(err) => { + if err.kind() == io::ErrorKind::TimedOut { + continue; + } #[cfg(feature = "graceful-shutdown")] let connections = format!( " {} current connections.", @@ -551,12 +609,50 @@ async fn accept( let connections = ""; // An error occurred - error!("Failed to accept() on listener.{connections}"); + error!("Failed to accept() on UDP listener.{connections}"); return Err(err); } }, + }; + + match descriptor.data.limiter().register(addr.ip()) { + LimitAction::Drop => { + drop(stream); + return Ok(()); + } + LimitAction::Send | LimitAction::Passed => {} } + + let descriptor = Arc::clone(&descriptor); + + #[cfg(feature = "graceful-shutdown")] + let shutdown_manager = Arc::clone(shutdown_manager); + let _task = spawn(async move { + #[cfg(feature = "graceful-shutdown")] + shutdown_manager.add_connection(); + let _result = handle_connection(stream, addr, descriptor, || { + #[cfg(feature = "async-networking")] + { + #[cfg(feature = "graceful-shutdown")] + { + !shutdown_manager.get_shutdown(threading::Ordering::Relaxed) + } + #[cfg(not(feature = "graceful-shutdown"))] + { + true + } + } + #[cfg(not(feature = "async-networking"))] + { + false + } + }) + .await; + #[cfg(feature = "graceful-shutdown")] + shutdown_manager.remove_connection(); + }) + .await; } } @@ -573,45 +669,89 @@ async fn accept( /// Will pass any errors from reading the request, making a TLS handshake, and writing the response. /// See [`handle_cache()`] and [`handle_request()`]; errors from them are passed up, through this fn. pub async fn handle_connection( - stream: TcpStream, + stream: Incoming, address: SocketAddr, descriptor: Arc, mut continue_accepting: impl FnMut() -> bool, ) -> io::Result<()> { - // LAYER 2 - #[cfg(feature = "https")] - let encrypted = { - encryption::Encryption::new_tcp(stream, descriptor.server_config.clone()) - .await - .map_err(|err| match err { - encryption::Error::Io(io) => io, - encryption::Error::Tls(tls) => io::Error::new(io::ErrorKind::InvalidData, tls), - }) - }?; - #[cfg(not(feature = "https"))] - let encrypted = encryption::Encryption::new_tcp(stream); - - let version = match encrypted.alpn_protocol() { - Some(b"h2") => Version::HTTP_2, - None | Some(b"http/1.1") => Version::HTTP_11, - Some(b"http/1.0") => Version::HTTP_10, - Some(b"http/0.9") => Version::HTTP_09, - Some(proto) => { - warn!("HTTP version not supported. Something is probably wrong with your alpn config. Client requested {}", String::from_utf8_lossy(proto)); - return Ok(()); + let (mut http, sni, version) = match stream { + Incoming::Tcp(stream) => { + // LAYER 2 + #[cfg(feature = "https")] + let encrypted = { + encryption::Encryption::new_tcp(stream, descriptor.server_config.clone()) + .await + .map_err(|err| match err { + encryption::Error::Io(io) => io, + encryption::Error::Tls(tls) => { + io::Error::new(io::ErrorKind::InvalidData, tls) + } + }) + }?; + #[cfg(not(feature = "https"))] + let encrypted = encryption::Encryption::new_tcp(stream); + + let version = match encrypted.alpn_protocol() { + Some(b"h2") => Version::HTTP_2, + None | Some(b"http/1.1") => Version::HTTP_11, + Some(b"http/1.0") => Version::HTTP_10, + Some(b"http/0.9") => Version::HTTP_09, + Some(proto) => { + warn!( + "HTTP version not supported. \ + Something is probably wrong with your alpn config. \ + Client requested {}", + String::from_utf8_lossy(proto) + ); + return Ok(()); + } + }; + let sni = encrypted.server_name().map(|s| s.to_compact_string()); + debug!("New connection requesting hostname '{sni:?}'"); + + // LAYER 3 + let http = application::HttpConnection::new(encrypted, version) + .await + .map_err::(application::Error::into)?; + (http, sni, version) + } + #[cfg(feature = "http3")] + Incoming::Udp(stream) => { + let handshake_data: Box = stream + .handshake_data() + .expect("connection is established") + .downcast() + .expect("we're using rustls"); + ( + application::HttpConnection::Http3( + h3::server::builder() + .build(h3_quinn::Connection::new(stream)) + .await + .map_err(application::Error::H3)?, + ), + handshake_data.server_name.map(CompactString::from), + Version::HTTP_3, + ) } }; - let sni = encrypted.server_name().map(|s| s.to_compact_string()); - debug!("New connection requesting hostname '{sni:?}'"); - - // LAYER 3 - let mut http = application::HttpConnection::new(encrypted, version) - .await - .map_err::(application::Error::into)?; debug!("Accepting requests from {}", address); - while let Ok((mut request, mut response_pipe)) = http + #[allow(unused_variables)] + let port = descriptor.port(); + #[allow(unused_variables)] + #[cfg(feature = "https")] + let secure = descriptor.server_config.is_some(); + #[cfg(all(feature = "http3", not(feature = "http2")))] + let alt_svc_header = format!("h3=\":{port}\"; ma=2592000"); + #[cfg(all(feature = "http2", not(feature = "http3")))] + let alt_svc_header = format!("h2=\":{port}\"; ma=2592000"); + #[cfg(all(feature = "http3", feature = "http2"))] + let alt_svc_header = format!("h3=\":{port}\"; ma=2592000, h2=\":{port}\"; ma=2592000"); + #[cfg(any(feature = "http2", feature = "http3"))] + let alt_svc_header = Bytes::from(alt_svc_header.into_bytes()); + + while let Ok((mut request, response_pipe)) = http .accept( descriptor .data @@ -668,19 +808,29 @@ pub async fn handle_connection( debug_assert!(descriptor.data.get_host(&host.name).is_some()); let hostname = host.name.clone(); let moved_host_collection = Arc::clone(&descriptor.data); + #[cfg(any(feature = "http2", feature = "http3"))] + let alt_svc_header = alt_svc_header.clone(); let future = async move { // UNWRAP: This host must be part of the Collection, as we got it from there. let host = moved_host_collection.get_host(&hostname).unwrap(); - let response = handle_cache(&mut request, address, host).await; + #[allow(unused_mut)] + let mut response = handle_cache(&mut request, address, host).await; + + #[cfg(any(feature = "http2", feature = "http3"))] + if secure { + response.response.headers_mut().append( + HeaderName::from_static("alt-svc"), + HeaderValue::from_maybe_shared(alt_svc_header).unwrap(), + ); + } - if let Err(err) = SendKind::Send(&mut response_pipe) + if let Err(err) = SendKind::Send(response_pipe) .send(response, &request, host, address) .await { error!("Got error from writing response: {:?}", err); } drop(request); - drop(response_pipe); }; // When version is HTTP/1, we block the socket if we begin listening to it again. @@ -706,13 +856,13 @@ pub async fn handle_connection( /// Most often, this is `Send`, but when a push promise is created, /// this will be `Push`. This can be used by [`extensions::Post`]. #[derive(Debug)] -pub enum SendKind<'a> { +pub enum SendKind { /// Send the response normally. - Send(&'a mut application::ResponsePipe), + Send(application::ResponsePipe), /// Send the response as a HTTP/2 push. - Push(&'a mut application::PushedResponsePipe), + Push(application::PushedResponsePipe), } -impl<'a> SendKind<'a> { +impl SendKind { /// Ensures correct version and length (only applicable for HTTP/1 connections) /// of a response according to inner enum variants. #[inline] @@ -729,7 +879,7 @@ impl<'a> SendKind<'a> { /// returns any errors with sending the data. #[inline] pub async fn send( - &mut self, + self, response: CacheReply, request: &FatRequest, host: &Host, @@ -791,7 +941,7 @@ impl<'a> SendKind<'a> { // Process post extensions host.extensions - .resolve_post(request, identity_body, response_pipe, address, host) + .resolve_post(request, identity_body, &mut body_pipe, address, host) .await; // Close the pipe. diff --git a/src/prelude.rs b/src/prelude.rs index 47a190a..6895f09 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -35,7 +35,7 @@ pub use error::{default as default_error, default_response as default_error_resp pub use extensions::{Package, Post, Prepare, Present, Prime, ResponsePipeFuture}; pub use host::{Collection as HostCollection, Host}; pub use read; -pub use shutdown::{AcceptAction, AcceptManager}; +pub(crate) use shutdown::{AcceptAction, AcceptManager}; pub use utils::{build_bytes, chars::*, parse, parse::SanitizeError, AsCleanDebug}; /// **Prelude:** file system @@ -55,12 +55,12 @@ pub mod fs { pub mod networking { pub use super::async_bits::*; #[cfg(not(feature = "async-networking"))] - pub use std::net::{TcpListener, TcpStream}; + pub use std::net::{TcpListener, TcpStream, UdpSocket}; #[cfg(all(feature = "async-networking", not(feature = "uring")))] - pub use tokio::net::{TcpListener, TcpSocket, TcpStream}; + pub use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; #[cfg(all(feature = "async-networking", feature = "uring"))] pub use { - crate::application::TcpStreamAsyncWrapper as TcpStream, tokio_uring::net::TcpListener, + crate::application::TcpStreamAsyncWrapper as TcpStream, tokio_uring::net::TcpListener,tokio::net::UdpSocket }; } diff --git a/src/shutdown.rs b/src/shutdown.rs index a7274c5..3e1f93b 100644 --- a/src/shutdown.rs +++ b/src/shutdown.rs @@ -123,7 +123,7 @@ impl Manager { /// Adds a listener to this manager. /// /// This is used so the `accept` future resolves immediately when the shutdown is triggered. - pub fn add_listener(&mut self, listener: TcpListener) -> AcceptManager { + pub(crate) fn add_listener(&mut self, listener: Listener) -> AcceptManager { AcceptManager { #[cfg(feature = "graceful-shutdown")] index: { @@ -351,31 +351,58 @@ impl Manager { /// Can either be a new connection or a shutdown signal. /// The listener should be dropped right after the shutdown signal is received. #[must_use] -pub enum AcceptAction { +pub(crate) enum AcceptAction { /// Shutdown signal; immediately drop this struct. + #[allow(dead_code)] Shutdown, /// Accept a new connection or handle a IO error. - Accept(io::Result<(TcpStream, SocketAddr)>), + AcceptTcp(io::Result<(TcpStream, SocketAddr)>), + /// Accept a new connection or handle a IO error. + #[cfg(feature = "http3")] + AcceptUdp(io::Result), } - impl Debug for AcceptAction { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { Self::Shutdown => write!(f, "Shutdown"), - Self::Accept(arg0) => f + Self::AcceptTcp(arg0) => f .debug_tuple("Accept") .field(&arg0.as_ref().map(|(_, addr)| addr)) .finish(), + #[cfg(feature = "http3")] + Self::AcceptUdp(arg0) => f + .debug_tuple("Accept") + .field( + &arg0 + .as_ref() + .map(h3_quinn::quinn::Connection::remote_address), + ) + .finish(), } } } +pub(crate) enum Listener { + Tcp(TcpListener), + #[cfg(feature = "http3")] + Udp(h3_quinn::Endpoint), +} +impl Listener { + pub(crate) fn local_addr(&self) -> SocketAddr { + match self { + Listener::Tcp(tcp) => tcp.local_addr(), + #[cfg(feature = "http3")] + Listener::Udp(udp) => udp.local_addr(), + } + .unwrap_or_else(|_| SocketAddr::V4(net::SocketAddrV4::new(net::Ipv4Addr::LOCALHOST, 0))) + } +} /// A wrapper around [`TcpListener`] (and `UdpListener` when HTTP/3 comes around) /// which waits for a new connection **or** a shutdown signal. #[must_use] -pub struct AcceptManager { +pub(crate) struct AcceptManager { #[cfg(feature = "graceful-shutdown")] index: WakerIndex, - listener: TcpListener, + listener: Listener, } // SAFETY: TcpListener is just an FD, and can be sent across threads. unsafe impl Send for AcceptManager {} @@ -400,7 +427,7 @@ impl AcceptManager { /// Please increase the count of connections on [`Manager`] when this connection is accepted /// and decrease it when the connection dies. #[allow(clippy::let_and_return)] // cfg - pub async fn accept(&mut self, _manager: &Manager) -> AcceptAction { + pub(crate) async fn accept(&mut self, _manager: &Manager) -> AcceptAction { #[cfg(feature = "async-networking")] { let action = AcceptFuture { @@ -418,12 +445,14 @@ impl AcceptManager { } #[cfg(not(feature = "async-networking"))] { - AcceptAction::Accept(self.listener.accept()) + match &mut self.listener { + Listener::Tcp(tcp) => AcceptAction::AcceptTcp(tcp.accept()), + } } } /// Returns a reference to the inner listener. #[must_use] - pub fn get_inner(&self) -> &TcpListener { + pub(crate) fn get_inner(&self) -> &Listener { &self.listener } } @@ -433,7 +462,18 @@ struct AcceptFuture<'a> { manager: &'a Manager, #[cfg(feature = "graceful-shutdown")] index: WakerIndex, - listener: &'a mut TcpListener, + listener: &'a mut Listener, +} +#[cfg(all(feature = "async-networking", feature = "http3"))] +async fn accept_udp(endpoint: &mut h3_quinn::Endpoint) -> io::Result { + if let Some(s) = endpoint.accept().await { + s.await.map_err(io::Error::from) + } else { + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "accept socket finished", + )) + } } #[cfg(feature = "async-networking")] impl<'a> AcceptFuture<'a> { @@ -451,34 +491,55 @@ impl<'a> AcceptFuture<'a> { self.manager.set_waker(self.index, Waker::clone(cx.waker())); Poll::Pending }); - let listener_fut = self.listener.accept(); - tokio::pin!(shutdown_fut); - #[cfg(feature = "uring")] - tokio::select! { - _ = shutdown_fut => AcceptAction::Shutdown, - r = listener_fut => AcceptAction::Accept(r.map(|(stream, addr)| (TcpStream::new(stream), addr))), - } - #[cfg(not(feature = "uring"))] - tokio::select! { - _ = shutdown_fut => AcceptAction::Shutdown, - r = listener_fut => AcceptAction::Accept(r), + match self.listener { + Listener::Tcp(tcp) => { + let listener_fut = tcp.accept(); + tokio::pin!(shutdown_fut); + #[cfg(feature = "uring")] + tokio::select! { + _ = shutdown_fut => AcceptAction::Shutdown, + r = listener_fut => AcceptAction::AcceptTcp(r.map(|(stream, addr)| (TcpStream::new(stream), addr))), + } + #[cfg(not(feature = "uring"))] + tokio::select! { + _ = shutdown_fut => AcceptAction::Shutdown, + r = listener_fut => AcceptAction::AcceptTcp(r), + } + } + #[cfg(feature = "http3")] + Listener::Udp(udp) => { + let listener_fut = accept_udp(udp); + tokio::pin!(shutdown_fut); + tokio::select! { + _ = shutdown_fut => AcceptAction::Shutdown, + r = listener_fut => AcceptAction::AcceptUdp(r), + + } + } } } #[cfg(not(feature = "graceful-shutdown"))] { #[cfg(feature = "uring")] { - AcceptAction::Accept( - self.listener - .accept() - .await - .map(|(stream, addr)| (TcpStream::new(stream), addr)), - ) + match self.listener { + Listener::Tcp(s) => AcceptAction::AcceptTcp( + s.accept() + .await + .map(|(stream, addr)| (TcpStream::new(stream), addr)), + ), + #[cfg(feature = "http3")] + Listener::Udp(udp) => AcceptAction::AcceptUdp(accept_udp(udp).await), + } } #[cfg(not(feature = "uring"))] { - AcceptAction::Accept(self.listener.accept().await) + match self.listener { + Listener::Tcp(s) => AcceptAction::AcceptTcp(s.accept().await), + #[cfg(feature = "http3")] + Listener::Udp(udp) => AcceptAction::AcceptUdp(accept_udp(udp).await), + } } } } diff --git a/src/websocket.rs b/src/websocket.rs index cd350f8..f51678b 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -116,18 +116,83 @@ pub async fn response(req: &FatRequest, host: &Host, future: ResponsePipeFuture) .with_future(future) } +/// Error from WebSocket operations +#[derive(Debug)] +pub enum Error { + /// WebSocket currently isn't supported for HTTP/3 nor HTTP/2. + WebSocketUnsupported, +} +/// Variants of WebSocket streams. +#[derive(Debug)] +pub enum WSStream<'a> { + /// + Http1(&'a Arc>), +} +impl<'a> AsyncRead for WSStream<'a> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Http1(s) => match s.try_lock() { + Err(_) => Poll::Pending, + Ok(mut s) => Pin::new(&mut *s).poll_read(cx, buf), + }, + } + } +} +impl<'a> AsyncWrite for WSStream<'a> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Http1(s) => match s.try_lock() { + Err(_) => Poll::Pending, + Ok(mut s) => Pin::new(&mut *s).poll_write(cx, buf), + }, + } + } + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Http1(s) => { + if let Ok(mut s) = s.try_lock() { + Pin::new(&mut *s).poll_flush(cx) + } else { + Poll::Pending + } + } + } + } +} + /// Get a [`tokio_tungstenite::WebSocketStream`] from the `pipe` given by [`response_pipe_fut!`]. /// /// # Examples /// /// See [`response()`]. +/// +/// # Errors +/// +/// Errors if `pipe` is [unsupported](Error::WebSocketUnsupported). pub async fn wrap( pipe: &mut ResponseBodyPipe, -) -> tokio_tungstenite::WebSocketStream<&mut ResponseBodyPipe> { - tokio_tungstenite::WebSocketStream::from_raw_socket( - pipe, - tungstenite::protocol::Role::Server, - None, - ) - .await +) -> Result, Error> { + match pipe { + ResponseBodyPipe::Http1(s) => Ok(tokio_tungstenite::WebSocketStream::from_raw_socket( + WSStream::Http1(s), + tungstenite::protocol::Role::Server, + None, + ) + .await), + #[cfg(feature = "http2")] + ResponseBodyPipe::Http2(_, _) => Err(Error::WebSocketUnsupported), + #[cfg(feature = "http3")] + ResponseBodyPipe::Http3(_) => Err(Error::WebSocketUnsupported), + } }