From 1dc49e5d58fea29795fef5705808fd15eb2f80fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9odore=20Pr=C3=A9vot?= Date: Mon, 20 Jan 2025 12:14:43 +0100 Subject: [PATCH] test(engineio): mock http/ws connections and improve integration tests (#447) * test(engineio): mock http/ws connections * test(engineio): fix rx to stream adapter --- crates/engineioxide/Cargo.toml | 3 +- crates/engineioxide/src/service/mod.rs | 23 +++ crates/engineioxide/src/transport/ws.rs | 2 +- .../engineioxide/tests/disconnect_reason.rs | 42 ++-- crates/engineioxide/tests/fixture.rs | 191 ++++++++++++------ 5 files changed, 174 insertions(+), 87 deletions(-) diff --git a/crates/engineioxide/Cargo.toml b/crates/engineioxide/Cargo.toml index b8ac8ae8..93425318 100644 --- a/crates/engineioxide/Cargo.toml +++ b/crates/engineioxide/Cargo.toml @@ -54,7 +54,8 @@ tracing-subscriber.workspace = true hyper = { workspace = true, features = ["server", "http1"] } criterion.workspace = true axum.workspace = true -hyper-util = { workspace = true, features = ["tokio", "client-legacy"] } +tokio-stream = "0.1" +tokio-util = { version = "0.7", features = ["io"], default-features = false } [features] v3 = ["memchr", "unicode-segmentation", "itoa"] diff --git a/crates/engineioxide/src/service/mod.rs b/crates/engineioxide/src/service/mod.rs index 261fbb1b..f34745fe 100644 --- a/crates/engineioxide/src/service/mod.rs +++ b/crates/engineioxide/src/service/mod.rs @@ -159,6 +159,29 @@ where } } +#[cfg(feature = "__test_harness")] +#[doc(hidden)] +impl EngineIoService +where + H: EngineIoHandler, +{ + /// Create a new engine.io conn over websocket through a raw stream. + /// Mostly used for testing. + pub fn ws_init( + &self, + conn: S, + protocol: ProtocolVersion, + sid: Option, + req_data: http::request::Parts, + ) -> impl std::future::Future> + where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + { + let engine = self.engine.clone(); + crate::transport::ws::on_init(engine, conn, protocol, sid, req_data) + } +} + /// A MakeService that always returns a clone of the [`EngineIoService`] it was created with. pub struct MakeEngineIoService { svc: EngineIoService, diff --git a/crates/engineioxide/src/transport/ws.rs b/crates/engineioxide/src/transport/ws.rs index f5a13978..a7acdcf7 100644 --- a/crates/engineioxide/src/transport/ws.rs +++ b/crates/engineioxide/src/transport/ws.rs @@ -100,7 +100,7 @@ pub fn new_req( /// Sends an open packet if it is not an upgrade from a polling request /// /// Read packets from the websocket and handle them, it will block until the connection is closed -async fn on_init( +pub async fn on_init( engine: Arc>, conn: S, protocol: ProtocolVersion, diff --git a/crates/engineioxide/tests/disconnect_reason.rs b/crates/engineioxide/tests/disconnect_reason.rs index b29ab819..394d66b1 100644 --- a/crates/engineioxide/tests/disconnect_reason.rs +++ b/crates/engineioxide/tests/disconnect_reason.rs @@ -18,10 +18,10 @@ use tokio::sync::mpsc; mod fixture; -use fixture::{create_server, send_req}; +use fixture::{create_server, create_ws_connection, send_req}; use tokio_tungstenite::tungstenite::Message; -use crate::fixture::{create_polling_connection, create_ws_connection}; +use crate::fixture::create_polling_connection; #[derive(Debug, Clone)] struct MyHandler { @@ -53,8 +53,8 @@ impl EngineIoHandler for MyHandler { #[tokio::test] pub async fn polling_heartbeat_timeout() { let (disconnect_tx, mut rx) = mpsc::channel(10); - create_server(MyHandler { disconnect_tx }, 1234).await; - create_polling_connection(1234).await; + let mut svc = create_server(MyHandler { disconnect_tx }).await; + create_polling_connection(&mut svc).await; let data = tokio::time::timeout(Duration::from_millis(500), rx.recv()) .await @@ -67,8 +67,8 @@ pub async fn polling_heartbeat_timeout() { #[tokio::test] pub async fn ws_heartbeat_timeout() { let (disconnect_tx, mut rx) = mpsc::channel(10); - create_server(MyHandler { disconnect_tx }, 12344).await; - let _stream = create_ws_connection(12344).await; + let mut svc = create_server(MyHandler { disconnect_tx }).await; + let _stream = create_ws_connection(&mut svc).await; let data = tokio::time::timeout(Duration::from_millis(500), rx.recv()) .await @@ -81,11 +81,11 @@ pub async fn ws_heartbeat_timeout() { #[tokio::test] pub async fn polling_transport_closed() { let (disconnect_tx, mut rx) = mpsc::channel(10); - create_server(MyHandler { disconnect_tx }, 1235).await; - let sid = create_polling_connection(1235).await; + let mut svc = create_server(MyHandler { disconnect_tx }).await; + let sid = create_polling_connection(&mut svc).await; send_req( - 1235, + &mut svc, format!("transport=polling&sid={sid}"), http::Method::POST, Some("1".into()), @@ -103,8 +103,8 @@ pub async fn polling_transport_closed() { #[tokio::test] pub async fn ws_transport_closed() { let (disconnect_tx, mut rx) = mpsc::channel(10); - create_server(MyHandler { disconnect_tx }, 12345).await; - let mut stream = create_ws_connection(12345).await; + let mut svc = create_server(MyHandler { disconnect_tx }).await; + let mut stream = create_ws_connection(&mut svc).await; stream.send(Message::Text("1".into())).await.unwrap(); @@ -119,10 +119,10 @@ pub async fn ws_transport_closed() { #[tokio::test] pub async fn multiple_http_polling() { let (disconnect_tx, mut rx) = mpsc::channel(10); - create_server(MyHandler { disconnect_tx }, 1236).await; - let sid = create_polling_connection(1236).await; + let mut svc = create_server(MyHandler { disconnect_tx }).await; + let sid = create_polling_connection(&mut svc).await; send_req( - 1236, + &mut svc, format!("transport=polling&sid={sid}"), http::Method::GET, None, @@ -131,13 +131,13 @@ pub async fn multiple_http_polling() { tokio::spawn(futures_util::future::join_all(vec![ send_req( - 1236, + &mut svc, format!("transport=polling&sid={sid}"), http::Method::GET, None, ), send_req( - 1236, + &mut svc, format!("transport=polling&sid={sid}"), http::Method::GET, None, @@ -155,10 +155,10 @@ pub async fn multiple_http_polling() { #[tokio::test] pub async fn polling_packet_parsing() { let (disconnect_tx, mut rx) = mpsc::channel(10); - create_server(MyHandler { disconnect_tx }, 1237).await; - let sid = create_polling_connection(1237).await; + let mut svc = create_server(MyHandler { disconnect_tx }).await; + let sid = create_polling_connection(&mut svc).await; send_req( - 1237, + &mut svc, format!("transport=polling&sid={sid}"), http::Method::POST, Some("aizdunazidaubdiz".into()), @@ -176,8 +176,8 @@ pub async fn polling_packet_parsing() { #[tokio::test] pub async fn ws_packet_parsing() { let (disconnect_tx, mut rx) = mpsc::channel(10); - create_server(MyHandler { disconnect_tx }, 12347).await; - let mut stream = create_ws_connection(12347).await; + let mut svc = create_server(MyHandler { disconnect_tx }).await; + let mut stream = create_ws_connection(&mut svc).await; stream .send(Message::Text("aizdunazidaubdiz".into())) .await diff --git a/crates/engineioxide/tests/fixture.rs b/crates/engineioxide/tests/fixture.rs index 42b9a25c..669a5662 100644 --- a/crates/engineioxide/tests/fixture.rs +++ b/crates/engineioxide/tests/fixture.rs @@ -1,21 +1,32 @@ use std::{ collections::VecDeque, - net::{IpAddr, Ipv4Addr, SocketAddr}, + future::Future, + io, + pin::Pin, sync::Arc, + task::{Context, Poll}, time::Duration, }; -use engineioxide::{config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService}; +use bytes::Bytes; +use engineioxide::{ + config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService, sid::Sid, + ProtocolVersion, +}; use http::Request; use http_body_util::{BodyExt, Either, Empty, Full}; -use hyper::server::conn::http1; -use hyper_util::{ - client::legacy::Client, - rt::{TokioExecutor, TokioIo}, -}; use serde::{Deserialize, Serialize}; -use tokio::net::{TcpListener, TcpStream}; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + sync::mpsc, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_tungstenite::{ + tungstenite::{handshake::client::generate_key, protocol::Role}, + WebSocketStream, +}; +use tokio_util::io::StreamReader; +use tower_service::Service; /// An OpenPacket is used to initiate a connection #[derive(Debug, Serialize, Deserialize, PartialEq, PartialOrd)] @@ -29,12 +40,12 @@ struct OpenPacket { } /// Params should be in the form of `key1=value1&key2=value2` -pub async fn send_req( - port: u16, +pub fn send_req( + svc: &mut EngineIoService, params: String, method: http::Method, body: Option, -) -> String { +) -> impl Future + 'static { let body = match body { Some(b) => Either::Left(Full::new(VecDeque::from(b.into_bytes()))), None => Either::Right(Empty::>::new()), @@ -42,28 +53,30 @@ pub async fn send_req( let req = Request::builder() .method(method) - .uri(format!( - "http://127.0.0.1:{port}/engine.io/?EIO=4&{}", - params - )) + .uri(format!("http://127.0.0.1/engine.io/?EIO=4&{}", params)) .body(body) .unwrap(); - let mut res = Client::builder(TokioExecutor::new()) - .build_http() - .request(req) - .await - .unwrap(); - let body = res.body_mut().collect().await.unwrap().to_bytes(); - String::from_utf8(body.to_vec()) - .unwrap() - .chars() - .skip(1) - .collect() + let res = svc.call(req); + async move { + let body = res + .await + .unwrap() + .body_mut() + .collect() + .await + .unwrap() + .to_bytes(); + String::from_utf8(body.to_vec()) + .unwrap() + .chars() + .skip(1) + .collect() + } } -pub async fn create_polling_connection(port: u16) -> String { +pub async fn create_polling_connection(svc: &mut EngineIoService) -> String { let body = send_req( - port, + svc, "transport=polling".to_string(), http::Method::GET, None, @@ -72,48 +85,98 @@ pub async fn create_polling_connection(port: u16) -> String { let open_packet: OpenPacket = serde_json::from_str(&body).unwrap(); open_packet.sid } -pub async fn create_ws_connection(port: u16) -> WebSocketStream> { - tokio_tungstenite::connect_async(format!( - "ws://127.0.0.1:{port}/engine.io/?EIO=4&transport=websocket" - )) +pub async fn create_ws_connection( + svc: &mut EngineIoService, +) -> WebSocketStream { + new_ws_mock_conn(svc, ProtocolVersion::V4, None).await +} + +pin_project_lite::pin_project! { + pub struct StreamImpl { + tx: mpsc::UnboundedSender>, + #[pin] + rx: StreamReader>, Bytes>, + } +} +impl StreamImpl { + pub fn new( + tx: mpsc::UnboundedSender>, + rx: mpsc::UnboundedReceiver>, + ) -> Self { + Self { + tx, + rx: StreamReader::new(UnboundedReceiverStream::new(rx)), + } + } +} + +impl AsyncRead for StreamImpl { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().rx.poll_read(cx, buf) + } +} +impl AsyncWrite for StreamImpl { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let len = buf.len(); + self.project() + .tx + .send(Ok(Bytes::copy_from_slice(buf))) + .unwrap(); + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} +async fn new_ws_mock_conn( + svc: &mut EngineIoService, + protocol: ProtocolVersion, + sid: Option, +) -> WebSocketStream { + let (tx, rx) = mpsc::unbounded_channel(); + let (tx1, rx1) = mpsc::unbounded_channel(); + + let parts = Request::builder() + .method("GET") + .header("Host", "127.0.0.1") + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", generate_key()) + .uri("ws://127.0.0.1/engine.io/?EIO=4&transport=websocket") + .body(http_body_util::Empty::::new()) + .unwrap() + .into_parts() + .0; + tokio::spawn(svc.ws_init(StreamImpl::new(tx, rx1), protocol, sid, parts)); + + tokio_tungstenite::WebSocketStream::from_raw_socket( + StreamImpl::new(tx1, rx), + Role::Client, + Default::default(), + ) .await - .unwrap() - .0 } -pub async fn create_server(handler: H, port: u16) { +pub async fn create_server(handler: H) -> EngineIoService { let config = EngineIoConfig::builder() .ping_interval(Duration::from_millis(300)) .ping_timeout(Duration::from_millis(200)) .max_payload(1e6 as u64) .build(); - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); - - let svc = EngineIoService::with_config(Arc::new(handler), config); - - let listener = TcpListener::bind(&addr).await.unwrap(); - tokio::spawn(async move { - // We start a loop to continuously accept incoming connections - loop { - let (stream, _) = listener.accept().await.unwrap(); - - // Use an adapter to access something implementing `tokio::io` traits as if they implement - // `hyper::rt` IO traits. - let io = TokioIo::new(stream); - let svc = svc.clone(); - - // Spawn a tokio task to serve multiple connections concurrently - tokio::task::spawn(async move { - // Finally, we bind the incoming connection to our `hello` service - if let Err(err) = http1::Builder::new() - .serve_connection(io, svc) - .with_upgrades() - .await - { - println!("Error serving connection: {:?}", err); - } - }); - } - }); + EngineIoService::with_config(Arc::new(handler), config) }