diff --git a/server/src/server.rs b/server/src/server.rs index 14a4222e1b..5a89a6d131 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1146,12 +1146,19 @@ where RpcServiceCfg::OnlyCalls, )); - Box::pin( - http::call_with_service(request, batch_config, max_request_size, rpc_service, max_response_size) - .map(Ok), - ) + Box::pin(async move { + let rp = + http::call_with_service(request, batch_config, max_request_size, rpc_service, max_response_size) + .await; + // NOTE: The `conn guard` must be held until the response is processed + // to respect the `max_connections` limit. + drop(conn); + Ok(rp) + }) } else { - Box::pin(async { http::response::denied() }.map(Ok)) + // NOTE: the `conn guard` is dropped when this function which is fine + // because it doesn't rely on any async operations. + Box::pin(async { Ok(http::response::denied()) }) } } } diff --git a/server/src/tests/ws.rs b/server/src/tests/ws.rs index 6a6b1ecf2a..dbedf03d72 100644 --- a/server/src/tests/ws.rs +++ b/server/src/tests/ws.rs @@ -569,16 +569,17 @@ async fn custom_subscription_id_works() { let addr = server.local_addr().unwrap(); let mut module = RpcModule::new(()); module - .register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, sink, _, _| async { - let sink = sink.accept().await.unwrap(); - - assert!(matches!(sink.subscription_id(), SubscriptionId::Str(id) if id == "0xdeadbeef")); - - loop { - let _ = &sink; - tokio::time::sleep(std::time::Duration::from_secs(30)).await; - } - }) + .register_subscription::<(), _, _>( + "subscribe_hello", + "subscribe_hello", + "unsubscribe_hello", + |_, sink, _, _| async { + let sink = sink.accept().await.unwrap(); + assert!(matches!(sink.subscription_id(), SubscriptionId::Str(id) if id == "0xdeadbeef")); + // Keep idle until it's unsubscribed. + futures_util::future::pending::<()>().await; + }, + ) .unwrap(); let _handle = server.start(module); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 9e7a2a71ee..1571cfc4a7 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -50,7 +50,7 @@ use jsonrpsee::core::server::SubscriptionMessage; use jsonrpsee::core::{JsonValue, StringError}; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::server::middleware::http::HostFilterLayer; -use jsonrpsee::server::{ServerBuilder, ServerHandle}; +use jsonrpsee::server::{ConnectionGuard, ServerBuilder, ServerHandle}; use jsonrpsee::types::error::{ErrorObject, UNKNOWN_ERROR_CODE}; use jsonrpsee::ws_client::WsClientBuilder; use jsonrpsee::{rpc_params, ResponsePayload, RpcModule}; @@ -1543,3 +1543,71 @@ async fn server_ws_low_api_works() { Ok(local_addr) } } + +#[tokio::test] +async fn http_connection_guard_works() { + use jsonrpsee::{server::ServerBuilder, RpcModule}; + use tokio::sync::mpsc; + + init_logger(); + + let (tx, mut rx) = mpsc::channel::<()>(1); + + let server_url = { + let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap(); + let server_url = format!("http://{}", server.local_addr().unwrap()); + let mut module = RpcModule::new(tx); + + module + .register_async_method("wait_until", |_, wait, _| async move { + wait.send(()).await.unwrap(); + wait.closed().await; + true + }) + .unwrap(); + + module + .register_async_method("connection_count", |_, _, ctx| async move { + let conn = ctx.get::().unwrap(); + conn.max_connections() - conn.available_connections() + }) + .unwrap(); + + let handle = server.start(module); + + tokio::spawn(handle.stopped()); + + server_url + }; + + let waiting_calls: Vec<_> = (0..2) + .map(|_| { + let client = HttpClientBuilder::default().build(&server_url).unwrap(); + tokio::spawn(async move { + let _ = client.request::("wait_until", rpc_params!()).await; + }) + }) + .collect(); + + // Wait until both calls are ACK:ed by the server. + rx.recv().await.unwrap(); + rx.recv().await.unwrap(); + + // Assert that two calls are waiting to be answered and the current one. + { + let client = HttpClientBuilder::default().build(&server_url).unwrap(); + let conn_count = client.request::("connection_count", rpc_params!()).await.unwrap(); + assert_eq!(conn_count, 3); + } + + // Complete the waiting calls. + drop(rx); + futures::future::join_all(waiting_calls).await; + + // Assert that connection count is back to 1. + { + let client = HttpClientBuilder::default().build(&server_url).unwrap(); + let conn_count = client.request::("connection_count", rpc_params!()).await.unwrap(); + assert_eq!(conn_count, 1); + } +}