diff --git a/benches/bench.rs b/benches/bench.rs index 6669e409ad..b5956bfe87 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -63,6 +63,14 @@ pub fn websocket_requests(crit: &mut Criterion) { run_concurrent_round_trip(&rt, crit, client.clone(), "ws_concurrent_round_trip"); } +pub fn batched_ws_requests(crit: &mut Criterion) { + let rt = TokioRuntime::new().unwrap(); + let url = rt.block_on(helpers::ws_server()); + let client = + Arc::new(rt.block_on(WsClientBuilder::default().max_concurrent_requests(1024 * 1024).build(&url)).unwrap()); + run_round_trip_with_batch(&rt, crit, client.clone(), "ws batch requests"); +} + fn run_round_trip(rt: &TokioRuntime, crit: &mut Criterion, client: Arc, name: &str) { crit.bench_function(name, |b| { b.iter(|| { diff --git a/http-server/src/module.rs b/http-server/src/module.rs index 65e4017411..0e9ad58e8e 100644 --- a/http-server/src/module.rs +++ b/http-server/src/module.rs @@ -1,6 +1,6 @@ use jsonrpsee_types::v2::error::{JsonRpcErrorCode, JsonRpcErrorObject, CALL_EXECUTION_FAILED_CODE}; use jsonrpsee_types::{ - error::{CallError, Error, InvalidParams}, + error::{CallError, Error}, traits::RpcMethod, v2::params::RpcParams, }; @@ -36,7 +36,7 @@ impl RpcModule { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: RpcMethod, + F: RpcMethod, { self.verify_method_name(method_name)?; @@ -45,7 +45,16 @@ impl RpcModule { Box::new(move |id, params, tx, _| { match callback(params) { Ok(res) => send_response(id, tx, res), - Err(InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::Failed(err)) => { + log::error!("Call failed with: {}", err); + let err = JsonRpcErrorObject { + code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), + message: &err.to_string(), + data: None, + }; + send_error(id, tx, err) + } }; Ok(()) @@ -99,7 +108,7 @@ impl RpcContextModule { Box::new(move |id, params, tx, _| { match callback(params, &*ctx) { Ok(res) => send_response(id, tx, res), - Err(CallError::InvalidParams(_)) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), Err(CallError::Failed(err)) => { let err = JsonRpcErrorObject { code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), diff --git a/http-server/src/server.rs b/http-server/src/server.rs index df4aa7fa6f..8de5ee8ed0 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -36,15 +36,14 @@ use hyper::{ service::{make_service_fn, service_fn}, Error as HyperError, }; -use jsonrpsee_types::error::{Error, GenericTransportError, InvalidParams}; +use jsonrpsee_types::error::{CallError, Error, GenericTransportError}; use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcRequest}; use jsonrpsee_types::v2::{error::JsonRpcErrorCode, params::RpcParams}; use jsonrpsee_utils::{ hyper_helpers::read_response_to_body, - server::{send_error, RpcSender}, + server::{collect_batch_response, send_error, RpcSender}, }; use serde::Serialize; -use serde_json::value::RawValue; use socket2::{Domain, Socket, Type}; use std::{ cmp, @@ -129,7 +128,7 @@ impl Server { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: Fn(RpcParams) -> Result + Send + Sync + 'static, + F: Fn(RpcParams) -> Result + Send + Sync + 'static, { self.root.register_method(method_name, callback) } @@ -162,23 +161,23 @@ impl Server { // Look up the "method" (i.e. function pointer) from the registered methods and run it passing in // the params from the request. The result of the computation is sent back over the `tx` channel and // the result(s) are collected into a `String` and sent back over the wire. - let execute = - move |id: Option<&RawValue>, tx: RpcSender, method_name: &str, params: Option<&RawValue>| { - if let Some(method) = methods.get(method_name) { - let params = RpcParams::new(params.map(|params| params.get())); - // NOTE(niklasad1): connection ID is unused thus hardcoded to `0`. - if let Err(err) = (method)(id, params, &tx, 0) { - log::error!( - "execution of method call '{}' failed: {:?}, request id={:?}", - method_name, - err, - id - ); - } - } else { - send_error(id, tx, JsonRpcErrorCode::MethodNotFound.into()); + let execute = move |tx: RpcSender, req: JsonRpcRequest| { + if let Some(method) = methods.get(&*req.method) { + let params = RpcParams::new(req.params.map(|params| params.get())); + // NOTE(niklasad1): connection ID is unused thus hardcoded to `0`. + if let Err(err) = (method)(req.id, params, &tx, 0) { + log::error!( + "execution of method call '{}' failed: {:?}, request id={:?}", + req.method, + err, + req.id + ); + send_error(req.id, &tx, JsonRpcErrorCode::ServerError(-1).into()); } - }; + } else { + send_error(req.id, &tx, JsonRpcErrorCode::MethodNotFound.into()); + } + }; // Run some validation on the http request, then read the body and try to deserialize it into one of // two cases: a single RPC request or a batch of RPC requests. @@ -203,7 +202,7 @@ impl Server { }; // NOTE(niklasad1): it's a channel because it's needed for batch requests. - let (tx, mut rx) = mpsc::unbounded(); + let (tx, mut rx) = mpsc::unbounded::(); // Is this a single request or a batch (or error)? let mut single = true; @@ -213,15 +212,13 @@ impl Server { // batch case and lastly the error. For the worst case – unparseable input – we make three calls // to [`serde_json::from_slice`] which is pretty annoying. // Our [issue](https://github.com/paritytech/jsonrpsee/issues/296). - if let Ok(JsonRpcRequest { id, method: method_name, params, .. }) = - serde_json::from_slice::(&body) - { - execute(id, &tx, &method_name, params); + if let Ok(req) = serde_json::from_slice::(&body) { + execute(&tx, req); } else if let Ok(batch) = serde_json::from_slice::>(&body) { if !batch.is_empty() { single = false; - for JsonRpcRequest { id, method: method_name, params, .. } in batch { - execute(id, &tx, &method_name, params); + for req in batch { + execute(&tx, req); } } else { send_error(None, &tx, JsonRpcErrorCode::InvalidRequest.into()); @@ -243,7 +240,7 @@ impl Server { let response = if single { rx.next().await.expect("Sender is still alive managed by us above; qed") } else { - collect_batch_responses(rx).await + collect_batch_response(rx).await }; log::debug!("[service_fn] sending back: {:?}", &response[..cmp::min(response.len(), 1024)]); Ok::<_, HyperError>(response::ok_response(response)) @@ -257,24 +254,6 @@ impl Server { } } -// Collect the results of all computations sent back on the ['Stream'] into a single `String` appropriately wrapped in -// `[`/`]`. -async fn collect_batch_responses(rx: mpsc::UnboundedReceiver) -> String { - let mut buf = String::with_capacity(2048); - buf.push('['); - let mut buf = rx - .fold(buf, |mut acc, response| async { - acc = [acc, response].concat(); - acc.push(','); - acc - }) - .await; - // Remove trailing comma - buf.pop(); - buf.push(']'); - buf -} - // Checks to that access control of the received request is the same as configured. fn access_control_is_valid( access_control: &AccessControl, diff --git a/test-utils/src/helpers.rs b/test-utils/src/helpers.rs index c0318c6750..a1b4fda70b 100644 --- a/test-utils/src/helpers.rs +++ b/test-utils/src/helpers.rs @@ -72,6 +72,13 @@ pub fn internal_error(id: Id) -> String { ) } +pub fn server_error(id: Id) -> String { + format!( + r#"{{"jsonrpc":"2.0","error":{{"code":-32000,"message":"Server error"}},"id":{}}}"#, + serde_json::to_string(&id).unwrap() + ) +} + /// Hardcoded server response when a client initiates a new subscription. /// /// NOTE: works only for one subscription because the subscription ID is hardcoded. diff --git a/types/src/error.rs b/types/src/error.rs index a76f2f169e..75212b3882 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -25,15 +25,15 @@ pub struct InvalidParams; pub enum CallError { #[error("Invalid params in the RPC call")] /// Invalid params in the call. - InvalidParams(InvalidParams), + InvalidParams, #[error("RPC Call failed: {0}")] /// The call failed. Failed(#[source] Box), } impl From for CallError { - fn from(params: InvalidParams) -> Self { - Self::InvalidParams(params) + fn from(_params: InvalidParams) -> Self { + Self::InvalidParams } } diff --git a/utils/src/server.rs b/utils/src/server.rs index e767561520..1f36c49dee 100644 --- a/utils/src/server.rs +++ b/utils/src/server.rs @@ -1,6 +1,7 @@ //! Shared helpers for JSON-RPC Servers. use futures_channel::mpsc; +use futures_util::stream::StreamExt; use jsonrpsee_types::v2::error::{JsonRpcError, JsonRpcErrorCode, JsonRpcErrorObject}; use jsonrpsee_types::v2::params::{RpcParams, TwoPointZero}; use jsonrpsee_types::v2::response::JsonRpcResponse; @@ -50,3 +51,21 @@ pub fn send_error(id: RpcId, tx: RpcSender, error: JsonRpcErrorObject) { log::error!("Error sending response to the client: {:?}", err) } } + +/// Read all the results of all method calls in a batch request from the ['Stream']. Format the result into a single +/// `String` appropriately wrapped in `[`/`]`. +pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver) -> String { + let mut buf = String::with_capacity(2048); + buf.push('['); + let mut buf = rx + .fold(buf, |mut acc, response| async { + acc = [acc, response].concat(); + acc.push(','); + acc + }) + .await; + // Remove trailing comma + buf.pop(); + buf.push(']'); + buf +} diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 03021c8841..109bd78223 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -26,6 +26,7 @@ use futures_channel::mpsc; use futures_util::io::{BufReader, BufWriter}; +use futures_util::stream::StreamExt; use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::Serialize; @@ -34,14 +35,14 @@ use soketto::handshake::{server::Response, Server as SokettoServer}; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::{TcpListener, ToSocketAddrs}; -use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; +use tokio_stream::wrappers::TcpListenerStream; use tokio_util::compat::TokioAsyncReadCompatExt; -use jsonrpsee_types::error::{Error, InvalidParams}; +use jsonrpsee_types::error::{CallError, Error}; use jsonrpsee_types::v2::error::JsonRpcErrorCode; use jsonrpsee_types::v2::params::{JsonRpcNotificationParams, RpcParams, TwoPointZero}; use jsonrpsee_types::v2::request::{JsonRpcInvalidRequest, JsonRpcNotification, JsonRpcRequest}; -use jsonrpsee_utils::server::{send_error, ConnectionId, Methods}; +use jsonrpsee_utils::server::{collect_batch_response, send_error, ConnectionId, Methods, RpcSender}; mod module; @@ -105,7 +106,7 @@ impl Server { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: Fn(RpcParams) -> Result + Send + Sync + 'static, + F: Fn(RpcParams) -> Result + Send + Sync + 'static, { self.root.register_method(method_name, callback) } @@ -149,7 +150,11 @@ impl Server { } } -async fn background_task(socket: tokio::net::TcpStream, methods: Arc, id: ConnectionId) -> anyhow::Result<()> { +async fn background_task( + socket: tokio::net::TcpStream, + methods: Arc, + conn_id: ConnectionId, +) -> anyhow::Result<()> { // For each incoming background_task we perform a handshake. let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat()))); @@ -166,6 +171,7 @@ async fn background_task(socket: tokio::net::TcpStream, methods: Arc, i let (mut sender, mut receiver) = server.into_builder().finish(); let (tx, mut rx) = mpsc::unbounded::(); + // Send results back to the client. tokio::spawn(async move { while let Some(response) = rx.next().await { let _ = sender.send_binary_mut(response.into_bytes()).await; @@ -173,31 +179,61 @@ async fn background_task(socket: tokio::net::TcpStream, methods: Arc, i } }); - let mut data = Vec::new(); + let mut data = Vec::with_capacity(100); + + // Look up the "method" (i.e. function pointer) from the registered methods and run it passing in + // the params from the request. The result of the computation is sent back over the `tx` channel and + // the result(s) are collected into a `String` and sent back over the wire. + let execute = move |tx: RpcSender, req: JsonRpcRequest| { + if let Some(method) = methods.get(&*req.method) { + let params = RpcParams::new(req.params.map(|params| params.get())); + if let Err(err) = (method)(req.id, params, &tx, conn_id) { + log::error!("execution of method call '{}' failed: {:?}, request id={:?}", req.method, err, req.id); + send_error(req.id, &tx, JsonRpcErrorCode::ServerError(-1).into()); + } + } else { + send_error(req.id, &tx, JsonRpcErrorCode::MethodNotFound.into()); + } + }; loop { data.clear(); receiver.receive_data(&mut data).await?; - match serde_json::from_slice::(&data) { - Ok(req) => { - let params = RpcParams::new(req.params.map(|params| params.get())); - - if let Some(method) = methods.get(&*req.method) { - (method)(req.id, params, &tx, id)?; - } else { - send_error(req.id, &tx, JsonRpcErrorCode::MethodNotFound.into()); + // For reasons outlined [here](https://github.com/serde-rs/json/issues/497), `RawValue` can't be used with + // untagged enums at the moment. This means we can't use an `SingleOrBatch` untagged enum here and have to try + // each case individually: first the single request case, then the batch case and lastly the error. For the + // worst case – unparseable input – we make three calls to [`serde_json::from_slice`] which is pretty annoying. + // Our [issue](https://github.com/paritytech/jsonrpsee/issues/296). + if let Ok(req) = serde_json::from_slice::(&data) { + execute(&tx, req); + } else if let Ok(batch) = serde_json::from_slice::>(&data) { + if !batch.is_empty() { + // Batch responses must be sent back as a single message so we read the results from each request in the + // batch and read the results off of a new channel, `rx_batch`, and then send the complete batch response + // back to the client over `tx`. + let (tx_batch, mut rx_batch) = mpsc::unbounded::(); + for req in batch { + execute(&tx_batch, req); + } + // Closes the receiving half of a channel without dropping it. This prevents any further messages from + // being sent on the channel. + rx_batch.close(); + let results = collect_batch_response(rx_batch).await; + if let Err(err) = tx.unbounded_send(results) { + log::error!("Error sending batch response to the client: {:?}", err) } + } else { + send_error(None, &tx, JsonRpcErrorCode::InvalidRequest.into()); } - Err(_) => { - let (id, code) = match serde_json::from_slice::(&data) { - Ok(req) => (req.id, JsonRpcErrorCode::InvalidRequest), - Err(_) => (None, JsonRpcErrorCode::ParseError), - }; + } else { + let (id, code) = match serde_json::from_slice::(&data) { + Ok(req) => (req.id, JsonRpcErrorCode::InvalidRequest), + Err(_) => (None, JsonRpcErrorCode::ParseError), + }; - send_error(id, &tx, code.into()); - } + send_error(id, &tx, code.into()); } } } diff --git a/ws-server/src/server/module.rs b/ws-server/src/server/module.rs index 3dd6840566..bed04c5a3e 100644 --- a/ws-server/src/server/module.rs +++ b/ws-server/src/server/module.rs @@ -1,9 +1,9 @@ use crate::server::{RpcParams, SubscriptionId, SubscriptionSink}; -use jsonrpsee_types::{error::InvalidParams, traits::RpcMethod, v2::error::CALL_EXECUTION_FAILED_CODE}; use jsonrpsee_types::{ error::{CallError, Error}, v2::error::{JsonRpcErrorCode, JsonRpcErrorObject}, }; +use jsonrpsee_types::{traits::RpcMethod, v2::error::CALL_EXECUTION_FAILED_CODE}; use jsonrpsee_utils::server::{send_error, send_response, Methods}; use parking_lot::Mutex; use rustc_hash::FxHashMap; @@ -38,7 +38,7 @@ impl RpcModule { pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where R: Serialize, - F: RpcMethod, + F: RpcMethod, { self.verify_method_name(method_name)?; @@ -47,7 +47,16 @@ impl RpcModule { Box::new(move |id, params, tx, _| { match callback(params) { Ok(res) => send_response(id, tx, res), - Err(InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::Failed(err)) => { + log::error!("Call failed with: {}", err); + let err = JsonRpcErrorObject { + code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), + message: &err.to_string(), + data: None, + }; + send_error(id, tx, err) + } }; Ok(()) @@ -157,7 +166,7 @@ impl RpcContextModule { Box::new(move |id, params, tx, _| { match callback(params, &*ctx) { Ok(res) => send_response(id, tx, res), - Err(CallError::InvalidParams(_)) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), + Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), Err(CallError::Failed(err)) => { let err = JsonRpcErrorObject { code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index 9fb3b2c67a..7534ae2388 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -5,8 +5,19 @@ use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient}; use jsonrpsee_types::error::{CallError, Error}; use serde_json::Value as JsonValue; +use std::fmt; use std::net::SocketAddr; +/// Applications can/should provide their own error. +#[derive(Debug)] +struct MyAppError; +impl fmt::Display for MyAppError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MyAppError") + } +} +impl std::error::Error for MyAppError {} + /// Spawns a dummy `JSONRPC v2 WebSocket` /// It has two hardcoded methods: "say_hello" and "add" pub async fn server() -> SocketAddr { @@ -25,6 +36,15 @@ pub async fn server() -> SocketAddr { Ok(sum) }) .unwrap(); + server.register_method("invalid_params", |_params| Err::<(), _>(CallError::InvalidParams)).unwrap(); + server.register_method("call_fail", |_params| Err::<(), _>(CallError::Failed(Box::new(MyAppError)))).unwrap(); + server + .register_method("sleep_for", |params| { + let sleep: Vec = params.parse()?; + std::thread::sleep(std::time::Duration::from_millis(sleep[0])); + Ok("Yawn!") + }) + .unwrap(); let addr = server.local_addr().unwrap(); tokio::spawn(async { server.start().await }); @@ -61,7 +81,7 @@ pub async fn server_with_context() -> SocketAddr { } #[tokio::test] -async fn single_method_call_works() { +async fn single_method_calls_works() { let addr = server().await; let mut client = WebSocketTestClient::new(addr).await.unwrap(); @@ -73,6 +93,54 @@ async fn single_method_call_works() { } } +#[tokio::test] +async fn slow_method_calls_works() { + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"sleep_for","params":[1000],"id":123}"#; + let response = client.send_request_text(req).await.unwrap(); + + assert_eq!(response, ok_response(JsonValue::String("Yawn!".to_owned()), Id::Num(123))); +} + +#[tokio::test] +async fn batch_method_call_works() { + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); + + let mut batch = Vec::new(); + batch.push(r#"{"jsonrpc":"2.0","method":"sleep_for","params":[1000],"id":123}"#.to_string()); + for i in 1..4 { + batch.push(format!(r#"{{"jsonrpc":"2.0","method":"say_hello","id":{}}}"#, i)); + } + let batch = format!("[{}]", batch.join(",")); + let response = client.send_request_text(batch).await.unwrap(); + assert_eq!( + response, + r#"[{"jsonrpc":"2.0","result":"Yawn!","id":123},{"jsonrpc":"2.0","result":"hello","id":1},{"jsonrpc":"2.0","result":"hello","id":2},{"jsonrpc":"2.0","result":"hello","id":3}]"# + ); +} + +#[tokio::test] +async fn batch_method_call_where_some_calls_fail() { + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); + + let mut batch = Vec::new(); + batch.push(r#"{"jsonrpc":"2.0","method":"say_hello","id":1}"#); + batch.push(r#"{"jsonrpc":"2.0","method":"call_fail","id":2}"#); + batch.push(r#"{"jsonrpc":"2.0","method":"add","params":[34, 45],"id":3}"#); + let batch = format!("[{}]", batch.join(",")); + + let response = client.send_request_text(batch).await.unwrap(); + + assert_eq!( + response, + r#"[{"jsonrpc":"2.0","result":"hello","id":1},{"jsonrpc":"2.0","error":{"code":-32000,"message":"MyAppError"},"id":2},{"jsonrpc":"2.0","result":79,"id":3}]"# + ); +} + #[tokio::test] async fn single_method_call_with_params_works() { let addr = server().await; @@ -202,3 +270,24 @@ async fn invalid_request_should_not_close_connection() { let response = client.send_request_text(request).await.unwrap(); assert_eq!(response, ok_response(JsonValue::String("hello".to_owned()), Id::Num(33))); } + +#[tokio::test] +async fn valid_request_that_fails_to_execute_should_not_close_connection() { + let addr = server().await; + let mut client = WebSocketTestClient::new(addr).await.unwrap(); + + // Good request, executes fine + let request = r#"{"jsonrpc":"2.0","method":"say_hello","id":33}"#; + let response = client.send_request_text(request).await.unwrap(); + assert_eq!(response, ok_response(JsonValue::String("hello".to_owned()), Id::Num(33))); + + // Good request, but causes error. + let req = r#"{"jsonrpc":"2.0","method":"call_fail","params":[],"id":123}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, r#"{"jsonrpc":"2.0","error":{"code":-32000,"message":"MyAppError"},"id":123}"#); + + // Connection is still good. + let request = r#"{"jsonrpc":"2.0","method":"say_hello","id":333}"#; + let response = client.send_request_text(request).await.unwrap(); + assert_eq!(response, ok_response(JsonValue::String("hello".to_owned()), Id::Num(333))); +}