diff --git a/client/ws-client/src/tests.rs b/client/ws-client/src/tests.rs index c4254bd64a..9da58eec80 100644 --- a/client/ws-client/src/tests.rs +++ b/client/ws-client/src/tests.rs @@ -38,6 +38,7 @@ use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::mocks::{Id, WebSocketTestServer}; use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_types::error::ErrorObjectOwned; +use jsonrpsee_types::{Notification, SubscriptionId, SubscriptionPayload, SubscriptionResponse}; use serde_json::Value as JsonValue; fn init_logger() { @@ -152,7 +153,7 @@ async fn subscription_works() { let server = WebSocketTestServer::with_hardcoded_subscription( "127.0.0.1:0".parse().unwrap(), server_subscription_id_response(Id::Num(0)), - server_subscription_response(JsonValue::String("hello my friend".to_owned())), + server_subscription_response("subscribe_hello", "hello my friend".into()), ) .with_default_timeout() .await @@ -192,10 +193,28 @@ async fn notification_handler_works() { } #[tokio::test] -async fn batched_notification_handler_works() { - let server = WebSocketTestServer::with_hardcoded_notification( +async fn batched_notifs_works() { + init_logger(); + + let notifs = vec![ + serde_json::to_value(&Notification::new("test".into(), "method_notif".to_string())).unwrap(), + serde_json::to_value(&Notification::new("sub".into(), "method_notif".to_string())).unwrap(), + serde_json::to_value(&SubscriptionResponse::new( + "sub".into(), + SubscriptionPayload { + subscription: SubscriptionId::Str("D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o".into()), + result: "sub_notif".to_string(), + }, + )) + .unwrap(), + ]; + + let serialized_batch = serde_json::to_string(¬ifs).unwrap(); + + let server = WebSocketTestServer::with_hardcoded_subscription( "127.0.0.1:0".parse().unwrap(), - server_batched_notification("test", "batched server originated notification works".into()), + server_subscription_id_response(Id::Num(0)), + serialized_batch, ) .with_default_timeout() .await @@ -203,11 +222,22 @@ async fn batched_notification_handler_works() { let uri = to_ws_uri_string(server.local_addr()); let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap(); + + // Ensure that subscription is returned back to the correct handle + // and is handled separately from ordinary notifications. { let mut nh: Subscription = - client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap(); + client.subscribe("sub", rpc_params![], "unsub").with_default_timeout().await.unwrap().unwrap(); + let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap(); + assert_eq!("sub_notif", response); + } + + // Ensure that method notif is returned back to the correct handle. + { + let mut nh: Subscription = + client.subscribe_to_method("sub").with_default_timeout().await.unwrap().unwrap(); let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap(); - assert_eq!("batched server originated notification works".to_owned(), response); + assert_eq!("method_notif", response); } } diff --git a/core/src/client/async_client/mod.rs b/core/src/client/async_client/mod.rs index 6ebb5f8827..ff29978055 100644 --- a/core/src/client/async_client/mod.rs +++ b/core/src/client/async_client/mod.rs @@ -724,14 +724,15 @@ fn handle_backend_messages( message: Option>, manager: &ThreadSafeRequestManager, max_buffer_capacity_per_subscription: usize, -) -> Result, Error> { +) -> Result, Error> { // Handle raw messages of form `ReceivedMessage::Bytes` (Vec) or ReceivedMessage::Data` (String). fn handle_recv_message( raw: &[u8], manager: &ThreadSafeRequestManager, max_buffer_capacity_per_subscription: usize, - ) -> Result, Error> { + ) -> Result, Error> { let first_non_whitespace = raw.iter().find(|byte| !byte.is_ascii_whitespace()); + let mut messages = Vec::new(); match first_non_whitespace { Some(b'{') => { @@ -741,13 +742,13 @@ fn handle_backend_messages( process_single_response(&mut manager.lock(), single, max_buffer_capacity_per_subscription)?; if let Some(unsub) = maybe_unsub { - return Ok(Some(FrontToBack::Request(unsub))); + return Ok(vec![FrontToBack::Request(unsub)]); } } // Subscription response. else if let Ok(response) = serde_json::from_slice::>(raw) { if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) { - return Ok(Some(FrontToBack::SubscriptionClosed(sub_id))); + return Ok(vec![FrontToBack::SubscriptionClosed(sub_id)]); } } // Subscription error response. @@ -784,6 +785,14 @@ fn handle_backend_messages( if id > r.end { r.end = id; } + } else if let Ok(response) = serde_json::from_str::>(r.get()) { + got_notif = true; + if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) { + messages.push(FrontToBack::SubscriptionClosed(sub_id)); + } + } else if let Ok(response) = serde_json::from_slice::>(raw) { + got_notif = true; + process_subscription_close_response(&mut manager.lock(), response); } else if let Ok(notif) = serde_json::from_str::>(r.get()) { got_notif = true; process_notification(&mut manager.lock(), notif); @@ -808,13 +817,13 @@ fn handle_backend_messages( } }; - Ok(None) + Ok(messages) } match message { Some(Ok(ReceivedMessage::Pong)) => { tracing::debug!(target: LOG_TARGET, "Received pong"); - Ok(None) + Ok(vec![]) } Some(Ok(ReceivedMessage::Bytes(raw))) => { handle_recv_message(raw.as_ref(), manager, max_buffer_capacity_per_subscription) @@ -1036,14 +1045,15 @@ where let Some(msg) = maybe_msg else { break Ok(()) }; match handle_backend_messages::(Some(msg), &manager, max_buffer_capacity_per_subscription) { - Ok(Some(msg)) => { - pending_unsubscribes.push(to_send_task.send(msg)); + Ok(messages) => { + for msg in messages { + pending_unsubscribes.push(to_send_task.send(msg)); + } } Err(e) => { tracing::error!(target: LOG_TARGET, "Failed to read message: {e}"); break Err(e); } - Ok(None) => (), } } _ = inactivity_stream.next() => { diff --git a/test-utils/src/helpers.rs b/test-utils/src/helpers.rs index 16d7e8184d..5ec7bc05f5 100644 --- a/test-utils/src/helpers.rs +++ b/test-utils/src/helpers.rs @@ -174,9 +174,9 @@ pub fn server_subscription_id_response(id: Id) -> String { } /// Server response to a hardcoded pending subscription -pub fn server_subscription_response(result: Value) -> String { +pub fn server_subscription_response(method: &str, result: Value) -> String { format!( - r#"{{"jsonrpc":"2.0","method":"bar","params":{{"subscription":"D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o","result":{}}}}}"#, + r#"{{"jsonrpc":"2.0","method":"{method}","params":{{"subscription":"D3wwzU6vvoUUYehv4qoFzq42DZnLoAETeFzeyk8swH4o","result":{}}}}}"#, serde_json::to_string(&result).unwrap() ) } @@ -186,11 +186,6 @@ pub fn server_notification(method: &str, params: Value) -> String { format!(r#"{{"jsonrpc":"2.0","method":"{}", "params":{} }}"#, method, serde_json::to_string(¶ms).unwrap()) } -/// Batched server originated notification -pub fn server_batched_notification(method: &str, params: Value) -> String { - format!(r#"[{{"jsonrpc":"2.0","method":"{}", "params":{} }}]"#, method, serde_json::to_string(¶ms).unwrap()) -} - pub async fn http_request(body: Body, uri: Uri) -> Result { let client = hyper::Client::new(); http_post(client, body, uri).await