From b987b81115d2934a72f4aa12a944328b0620c30a Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 4 Jun 2021 11:49:09 +0200 Subject: [PATCH 01/21] [rpc module]: server-side close subscription (#355) * [rpc module]: server-side close subscription Add functionality that closes the subscription after the sink has been dropped. * [integration tests]: add timeout on futures * remove global subscriber mutex * fix nit * [client types]: fix #349 Subscription::next() propogate error when parsing the response fails * [client types]: fix #349 Subscription::next() propogate error when parsing the response fails * unify subscription and notification * rename again * send notif response when subscription is dropped * stray debug stuff * Update utils/src/server/rpc_module.rs Co-authored-by: David * Revert "[client types]: fix #349" This reverts commit c4fefade719ebd8c964a62d0eb16f89db3346ccd. * Revert "[integration tests]: add timeout on futures" This reverts commit 21dfb99649aaaa8c847a693b4510e0202498fa2c. * show that actual edge-case * fix nit * Update types/src/traits.rs * fix bad merge * ugly; but works * complete solution * get rid of Option * Update tests/tests/integration_tests.rs * Update utils/src/server/rpc_module.rs * Update utils/src/server/rpc_module.rs Co-authored-by: David * grumbles: fix faulty early return * remove weird abstraction KeepAlive * fix nits * revert test timeouts * address grumbles * fix build Co-authored-by: David --- examples/weather.rs | 2 +- examples/ws_sub_with_params.rs | 4 +- examples/ws_subscription.rs | 7 ++- tests/tests/helpers.rs | 16 ++++- tests/tests/integration_tests.rs | 15 +++++ types/src/client.rs | 18 ++++-- types/src/error.rs | 18 ++++++ utils/src/server/rpc_module.rs | 103 ++++++++++++++++++++++--------- 8 files changed, 143 insertions(+), 40 deletions(-) diff --git a/examples/weather.rs b/examples/weather.rs index fa9320ef5e..c7bab569e3 100644 --- a/examples/weather.rs +++ b/examples/weather.rs @@ -106,7 +106,7 @@ async fn run_server() -> anyhow::Result { let cx = Mutex::new(WeatherApiCx { api_client, last_weather }); let mut module = RpcModule::new(cx); module - .register_subscription("weather_sub", "weather_unsub", |params, sink, cx| { + .register_subscription("weather_sub", "weather_unsub", |params, mut sink, cx| { let params: (String, String) = params.parse()?; log::debug!(target: "server", "Subscribed with params={:?}", params); std::thread::spawn(move || loop { diff --git a/examples/ws_sub_with_params.rs b/examples/ws_sub_with_params.rs index fafd163296..3a56f9b363 100644 --- a/examples/ws_sub_with_params.rs +++ b/examples/ws_sub_with_params.rs @@ -56,7 +56,7 @@ async fn run_server() -> anyhow::Result { let mut server = WsServerBuilder::default().build("127.0.0.1:0").await?; let mut module = RpcModule::new(()); module - .register_subscription("sub_one_param", "unsub_one_param", |params, sink, _| { + .register_subscription("sub_one_param", "unsub_one_param", |params, mut sink, _| { let idx: usize = params.one()?; std::thread::spawn(move || loop { let _ = sink.send(&LETTERS.chars().nth(idx)); @@ -66,7 +66,7 @@ async fn run_server() -> anyhow::Result { }) .unwrap(); module - .register_subscription("sub_params_two", "unsub_params_two", |params, sink, _| { + .register_subscription("sub_params_two", "unsub_params_two", |params, mut sink, _| { let (one, two): (usize, usize) = params.parse()?; std::thread::spawn(move || loop { let _ = sink.send(&LETTERS[one..two].to_string()); diff --git a/examples/ws_subscription.rs b/examples/ws_subscription.rs index 7e9141813c..2942d06858 100644 --- a/examples/ws_subscription.rs +++ b/examples/ws_subscription.rs @@ -25,6 +25,7 @@ // DEALINGS IN THE SOFTWARE. use jsonrpsee::{ + types::Error, ws_client::{traits::SubscriptionClient, v2::params::JsonRpcParams, Subscription, WsClientBuilder}, ws_server::{RpcModule, WsServerBuilder}, }; @@ -55,9 +56,11 @@ async fn main() -> anyhow::Result<()> { async fn run_server() -> anyhow::Result { let mut server = WsServerBuilder::default().build("127.0.0.1:0").await?; let mut module = RpcModule::new(()); - module.register_subscription("subscribe_hello", "unsubscribe_hello", |_, sink, _| { + module.register_subscription("subscribe_hello", "unsubscribe_hello", |_, mut sink, _| { std::thread::spawn(move || loop { - sink.send(&"hello my friend").unwrap(); + if let Err(Error::SubscriptionClosed(_)) = sink.send(&"hello my friend") { + return; + } std::thread::sleep(std::time::Duration::from_secs(1)); }); Ok(()) diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 788884cda0..2eaf8e1b5a 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -41,7 +41,7 @@ pub async fn websocket_server_with_subscription() -> SocketAddr { module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); module - .register_subscription("subscribe_hello", "unsubscribe_hello", |_, sink, _| { + .register_subscription("subscribe_hello", "unsubscribe_hello", |_, mut sink, _| { std::thread::spawn(move || loop { let _ = sink.send(&"hello from subscription"); std::thread::sleep(Duration::from_millis(50)); @@ -51,7 +51,7 @@ pub async fn websocket_server_with_subscription() -> SocketAddr { .unwrap(); module - .register_subscription("subscribe_foo", "unsubscribe_foo", |_, sink, _| { + .register_subscription("subscribe_foo", "unsubscribe_foo", |_, mut sink, _| { std::thread::spawn(move || loop { let _ = sink.send(&1337); std::thread::sleep(Duration::from_millis(100)); @@ -61,7 +61,7 @@ pub async fn websocket_server_with_subscription() -> SocketAddr { .unwrap(); module - .register_subscription("subscribe_add_one", "unsubscribe_add_one", |params, sink, _| { + .register_subscription("subscribe_add_one", "unsubscribe_add_one", |params, mut sink, _| { let mut count: usize = params.one()?; std::thread::spawn(move || loop { count = count.wrapping_add(1); @@ -72,6 +72,16 @@ pub async fn websocket_server_with_subscription() -> SocketAddr { }) .unwrap(); + module + .register_subscription("subscribe_noop", "unsubscribe_noop", |_, mut sink, _| { + std::thread::spawn(move || { + std::thread::sleep(Duration::from_secs(1)); + sink.close("Server closed the stream because it was lazy".into()) + }); + Ok(()) + }) + .unwrap(); + server.register_module(module).unwrap(); rt.block_on(async move { server_started_tx.send(server.local_addr().unwrap()).unwrap(); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 2cf460da8b..47a5663496 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -241,3 +241,18 @@ async fn ws_unsubscribe_releases_request_slots() { let _: Subscription = client.subscribe("subscribe_hello", JsonRpcParams::NoParams, "unsubscribe_hello").await.unwrap(); } + +#[tokio::test] +async fn server_should_be_able_to_close_subscriptions() { + let server_addr = websocket_server_with_subscription().await; + let server_url = format!("ws://{}", server_addr); + + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + let mut sub: Subscription = + client.subscribe("subscribe_noop", JsonRpcParams::NoParams, "unsubscribe_noop").await.unwrap(); + + let res = sub.next().await; + + assert!(matches!(res, Err(Error::SubscriptionClosed(_)))); +} diff --git a/types/src/client.rs b/types/src/client.rs index 8f4a87dac3..5793d83357 100644 --- a/types/src/client.rs +++ b/types/src/client.rs @@ -1,8 +1,8 @@ -use crate::{v2::params::SubscriptionId, Error}; +use crate::{error::SubscriptionClosedError, v2::params::SubscriptionId, Error}; use core::marker::PhantomData; use futures_channel::{mpsc, oneshot}; use futures_util::{future::FutureExt, sink::SinkExt, stream::StreamExt}; -use serde::de::DeserializeOwned; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::Value as JsonValue; /// Subscription kind @@ -15,6 +15,15 @@ pub enum SubscriptionKind { Method(String), } +/// Internal type to detect whether a subscription response from +/// the server was a valid notification or should be treated as an error. +#[derive(Debug, Deserialize, Serialize)] +#[serde(untagged)] +enum NotifResponse { + Ok(Notif), + Err(SubscriptionClosedError), +} + /// Active subscription on the client. /// /// It will automatically unsubscribe in the [`Subscription::drop`] so no need to explicitly call @@ -124,8 +133,9 @@ where /// may happen if the channel becomes full or is dropped. pub async fn next(&mut self) -> Result, Error> { match self.notifs_rx.next().await { - Some(n) => match serde_json::from_value(n) { - Ok(parsed) => Ok(Some(parsed)), + Some(n) => match serde_json::from_value::>(n) { + Ok(NotifResponse::Ok(parsed)) => Ok(Some(parsed)), + Ok(NotifResponse::Err(e)) => Err(Error::SubscriptionClosed(e)), Err(e) => Err(e.into()), }, None => Ok(None), diff --git a/types/src/error.rs b/types/src/error.rs index 61cb4066b9..607e193fde 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -1,3 +1,4 @@ +use serde::{Deserialize, Serialize}; use std::fmt; /// Convenience type for displaying errors. @@ -68,6 +69,9 @@ pub enum Error { /// Subscribe and unsubscribe method names are the same. #[error("Cannot use the same method name for subscribe and unsubscribe, used: {0}")] SubscriptionNameConflict(String), + /// Subscription got closed. + #[error("Subscription closed: {0:?}")] + SubscriptionClosed(SubscriptionClosedError), /// Request timeout #[error("Request timeout")] RequestTimeout, @@ -79,6 +83,20 @@ pub enum Error { Custom(String), } +/// Error type with a special `subscription_closed` field to detect that +/// a subscription has been closed to distinguish valid items produced +/// by the server on the subscription stream from an error. +#[derive(Deserialize, Serialize, Debug)] +pub struct SubscriptionClosedError { + subscription_closed: String, +} + +impl From for SubscriptionClosedError { + fn from(msg: String) -> Self { + Self { subscription_closed: msg } + } +} + /// Generic transport error. #[derive(Debug, thiserror::Error)] pub enum GenericTransportError { diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 2e9c4f9dcc..602e724fd8 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -1,7 +1,7 @@ use crate::server::helpers::{send_error, send_response}; use futures_channel::{mpsc, oneshot}; use futures_util::{future::BoxFuture, FutureExt}; -use jsonrpsee_types::error::{CallError, Error}; +use jsonrpsee_types::error::{CallError, Error, SubscriptionClosedError}; use jsonrpsee_types::v2::error::{JsonRpcErrorCode, JsonRpcErrorObject, CALL_EXECUTION_FAILED_CODE}; use jsonrpsee_types::v2::params::{Id, JsonRpcNotificationParams, OwnedId, OwnedRpcParams, RpcParams, TwoPointZero}; use jsonrpsee_types::v2::request::JsonRpcRequest; @@ -31,7 +31,14 @@ pub type SubscriptionId = u64; /// Sink that is used to send back the result to the server for a specific method. pub type MethodSink = mpsc::UnboundedSender; -type Subscribers = Arc)>>>; +type Subscribers = Arc)>>>; + +/// Represent a unique subscription entry based on [`SubscriptionId`] and [`ConnectionId`]. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +struct SubscriptionKey { + conn_id: ConnectionId, + sub_id: SubscriptionId, +} /// Callback wrapper that can be either sync or async. pub enum MethodCallback { @@ -134,13 +141,12 @@ impl Methods { pub struct RpcModule { ctx: Arc, methods: Methods, - subscribers: Subscribers, } impl RpcModule { /// Create a new module with a given shared `Context`. pub fn new(ctx: Context) -> Self { - Self { ctx: Arc::new(ctx), methods: Default::default(), subscribers: Default::default() } + Self { ctx: Arc::new(ctx), methods: Default::default() } } /// Register a new synchronous RPC method, which computes the response with the given callback. pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> @@ -228,7 +234,7 @@ impl RpcModule { /// use jsonrpsee_utils::server::rpc_module::RpcModule; /// /// let mut ctx = RpcModule::new(99_usize); - /// ctx.register_subscription("sub", "unsub", |params, sink, ctx| { + /// ctx.register_subscription("sub", "unsub", |params, mut sink, ctx| { /// let x: usize = params.one()?; /// std::thread::spawn(move || { /// let sum = x + (*ctx); @@ -255,17 +261,20 @@ impl RpcModule { self.methods.verify_method_name(unsubscribe_method_name)?; let ctx = self.ctx.clone(); + let subscribers = Subscribers::default(); + { - let subscribers = self.subscribers.clone(); + let subscribers = subscribers.clone(); self.methods.callbacks.insert( subscribe_method_name, - MethodCallback::Sync(Box::new(move |id, params, method_sink, conn| { - let (online_tx, online_rx) = oneshot::channel::<()>(); + MethodCallback::Sync(Box::new(move |id, params, method_sink, conn_id| { + let (conn_tx, conn_rx) = oneshot::channel::<()>(); let sub_id = { const JS_NUM_MASK: SubscriptionId = !0 >> 11; let sub_id = rand::random::() & JS_NUM_MASK; + let uniq_sub = SubscriptionKey { conn_id, sub_id }; - subscribers.lock().insert((conn, sub_id), (method_sink.clone(), online_rx)); + subscribers.lock().insert(uniq_sub, (method_sink.clone(), conn_rx)); sub_id }; @@ -274,8 +283,9 @@ impl RpcModule { let sink = SubscriptionSink { inner: method_sink.clone(), method: subscribe_method_name, - sub_id, - is_online: online_tx, + subscribers: subscribers.clone(), + uniq_sub: SubscriptionKey { conn_id, sub_id }, + is_connected: Some(conn_tx), }; callback(params, sink, ctx.clone()) })), @@ -283,13 +293,12 @@ impl RpcModule { } { - let subscribers = self.subscribers.clone(); self.methods.callbacks.insert( unsubscribe_method_name, - MethodCallback::Sync(Box::new(move |id, params, tx, conn| { + MethodCallback::Sync(Box::new(move |id, params, tx, conn_id| { let sub_id = params.one()?; - subscribers.lock().remove(&(conn, sub_id)); - send_response(id, tx, "Unsubscribed"); + subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }); + send_response(id, &tx, "Unsubscribed"); Ok(()) })), @@ -320,42 +329,80 @@ pub struct SubscriptionSink { inner: mpsc::UnboundedSender, /// Method. method: &'static str, - /// SubscriptionID, - sub_id: SubscriptionId, - /// Whether the subscriber is still alive (to avoid send messages that the subscriber is not interested in). - is_online: oneshot::Sender<()>, + /// Unique subscription. + uniq_sub: SubscriptionKey, + /// Shared Mutex of subscriptions for this method. + subscribers: Subscribers, + /// A type to track whether the subscription is active (the subscriber is connected). + /// + /// None - implies that the subscription as been closed. + is_connected: Option>, } impl SubscriptionSink { /// Send message on this subscription. - pub fn send(&self, result: &T) -> Result<(), Error> { + pub fn send(&mut self, result: &T) -> Result<(), Error> { let result = to_raw_value(result)?; self.send_raw_value(&result) } - fn send_raw_value(&self, result: &RawValue) -> Result<(), Error> { + fn send_raw_value(&mut self, result: &RawValue) -> Result<(), Error> { let msg = serde_json::to_string(&JsonRpcSubscriptionResponse { jsonrpc: TwoPointZero, method: self.method, - params: JsonRpcNotificationParams { subscription: self.sub_id, result: &*result }, + params: JsonRpcNotificationParams { subscription: self.uniq_sub.sub_id, result: &*result }, })?; self.inner_send(msg).map_err(Into::into) } - fn inner_send(&self, msg: String) -> Result<(), Error> { - if self.is_online() { - self.inner.unbounded_send(msg).map_err(|e| Error::Internal(e.into_send_error())) + fn inner_send(&mut self, msg: String) -> Result<(), Error> { + let res = if let Some(conn) = self.is_connected.as_ref() { + if !conn.is_canceled() { + // unbounded send only fails if the receiver has been dropped. + self.inner.unbounded_send(msg).map_err(|_| subscription_closed_by_client()) + } else { + Err(subscription_closed_by_client()) + } } else { - Err(Error::Custom("Subscription canceled".into())) + Err(subscription_closed_by_client()) + }; + + if let Err(e) = &res { + self.close(e.to_string()); } + + res } - fn is_online(&self) -> bool { - !self.is_online.is_canceled() + /// Close the subscription sink with a customized error message. + pub fn close(&mut self, close_reason: String) { + self.is_connected.take(); + if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) { + let result = + to_raw_value(&SubscriptionClosedError::from(close_reason)).expect("valid json infallible; qed"); + let msg = serde_json::to_string(&JsonRpcSubscriptionResponse { + jsonrpc: TwoPointZero, + method: self.method, + params: JsonRpcNotificationParams { subscription: self.uniq_sub.sub_id, result: &*result }, + }) + .expect("valid json infallible; qed"); + let _ = sink.unbounded_send(msg); + } } } +impl Drop for SubscriptionSink { + fn drop(&mut self) { + self.close(format!("Subscription: {} closed by the server", self.uniq_sub.sub_id)); + } +} + +fn subscription_closed_by_client() -> Error { + const CLOSE_REASON: &str = "Subscription closed by the client"; + Error::SubscriptionClosed(CLOSE_REASON.to_owned().into()) +} + #[cfg(test)] mod tests { use super::*; From a3feec75d7151c7a1f7a549096e5b5d63c78b2a2 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 4 Jun 2021 13:57:03 +0200 Subject: [PATCH 02/21] chore(scripts): publish script (#354) * chore(scripts): publish script * use script from jsonrpc --- .scripts/publish.sh | 96 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100755 .scripts/publish.sh diff --git a/.scripts/publish.sh b/.scripts/publish.sh new file mode 100755 index 0000000000..4619d6c048 --- /dev/null +++ b/.scripts/publish.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# +# This script is copied from `https://github.com/paritytech/jsonrpc` with some minor tweaks. + +set -eu + +ORDER=(types proc-macros utils http-client http-server ws-client ws-server jsonrpsee) + +function read_toml () { + NAME="" + VERSION="" + NAME=$(grep "^name" ./Cargo.toml | sed -e 's/.*"\(.*\)"/\1/') + VERSION=$(grep "^version" ./Cargo.toml | sed -e 's/.*"\(.*\)"/\1/') +} +function remote_version () { + REMOTE_VERSION="" + REMOTE_VERSION=$(cargo search "$NAME" | grep "^$NAME =" | sed -e 's/.*"\(.*\)".*/\1/') +} + +# First display the plan +for CRATE_DIR in ${ORDER[@]}; do + cd $CRATE_DIR > /dev/null + read_toml + echo "$NAME@$VERSION" + cd - > /dev/null +done + +read -p ">>>> Really publish?. Press [enter] to continue. " + +set -x + +cargo clean + +set +x + +# Then actually perform publishing. +for CRATE_DIR in ${ORDER[@]}; do + cd $CRATE_DIR > /dev/null + read_toml + remote_version + # Seems the latest version matches, skip by default. + if [ "$REMOTE_VERSION" = "$VERSION" ] || [[ "$REMOTE_VERSION" > "$VERSION" ]]; then + RET="" + echo "Seems like $NAME@$REMOTE_VERSION is already published. Continuing in 5s. " + read -t 5 -p ">>>> Type [r][enter] to retry, or [enter] to continue... " RET || true + if [ "$RET" != "r" ]; then + echo "Skipping $NAME@$VERSION" + cd - > /dev/null + continue + fi + fi + + # Attempt to publish (allow retries) + while : ; do + # give the user an opportunity to abort or skip before publishing + echo "🚀 Publishing $NAME@$VERSION..." + sleep 3 + + set +e && set -x + cargo publish $@ + RES=$? + set +x && set -e + # Check if it succeeded + if [ "$RES" != "0" ]; then + CHOICE="" + echo "##### Publishing $NAME failed" + read -p ">>>>> Type [s][enter] to skip, or [enter] to retry.. " CHOICE + if [ "$CHOICE" = "s" ]; then + break + fi + else + break + fi + done + + # Wait again to make sure that the new version is published and available. + echo "Waiting for $NAME@$VERSION to become available at the registry..." + while : ; do + sleep 3 + remote_version + if [ "$REMOTE_VERSION" = "$VERSION" ]; then + echo "🥳 $NAME@$VERSION published succesfully." + sleep 3 + break + else + echo "#### Got $NAME@$REMOTE_VERSION but expected $NAME@$VERSION. Retrying..." + fi + done + cd - > /dev/null +done + +echo "Tagging jsonrpsee@$VERSION" +set -x +git tag -a v$VERSION -m "Version $VERSION" +sleep 3 +git push --tags From b3a0748b5eeb72c7183ef6a3e431a9d6533b2a20 Mon Sep 17 00:00:00 2001 From: David Date: Fri, 4 Jun 2021 13:58:22 +0200 Subject: [PATCH 03/21] Release prep for v0.2 (#368) * Release prep * Mention proc macro limitations * Mention publish script --- CHANGELOG.md | 20 ++++++++++++++++++++ http-client/Cargo.toml | 6 +++--- http-server/Cargo.toml | 6 +++--- jsonrpsee/Cargo.toml | 16 ++++++++-------- proc-macros/Cargo.toml | 2 +- test-utils/Cargo.toml | 2 +- types/Cargo.toml | 2 +- utils/Cargo.toml | 4 ++-- ws-client/Cargo.toml | 4 ++-- ws-server/Cargo.toml | 6 +++--- 10 files changed, 44 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 545cf7dff8..aa3d09b1e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,3 +5,23 @@ The format is based on [Keep a Changelog]. [Keep a Changelog]: http://keepachangelog.com/en/1.0.0/ ## [Unreleased] + +## [v0.2.0] – 2021-06-04 + +[changed] The crate structure changed to several smaller crates, enabling users to pick and choose. The `jsonrpsee` crate works as a façade crate for users to pick&chose what components they wish to use. + +[changed] Starting with this release, the project is assuming `tokio` is the async executor. + +[changed] Revamped RPC subscription/method definition: users now provide closures when initializing the server and it is no longer possible to register new methods after the server started. + +[changed] Refactored the internals from the ground up. + +[added] Support for async methods + +[added] Support for batch requests (http/ws) + +[changed] the proc macros are currently limited to client side. + +[added] crate publication script + +## [v0.1.0] - 2020-02-28 diff --git a/http-client/Cargo.toml b/http-client/Cargo.toml index 691e96c030..f9959c5a26 100644 --- a/http-client/Cargo.toml +++ b/http-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsonrpsee-http-client" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies ", "Pierre Krieger "] description = "HTTP client for JSON-RPC" edition = "2018" @@ -15,8 +15,8 @@ hyper13-rustls = { package = "hyper-rustls", version = "0.21", optional = true } hyper14-rustls = { package = "hyper-rustls", version = "0.22", optional = true } hyper14 = { package = "hyper", version = "0.14", features = ["client", "http1", "http2", "tcp"], optional = true } hyper13 = { package = "hyper", version = "0.13", optional = true } -jsonrpsee-types = { path = "../types", version = "=0.2.0-alpha.7" } -jsonrpsee-utils = { path = "../utils", version = "=0.2.0-alpha.7", optional = true } +jsonrpsee-types = { path = "../types", version = "0.2.0" } +jsonrpsee-utils = { path = "../utils", version = "0.2.0", optional = true } log = "0.4" serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = "1.0" diff --git a/http-server/Cargo.toml b/http-server/Cargo.toml index 36b03df117..a5cf25f3dd 100644 --- a/http-server/Cargo.toml +++ b/http-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsonrpsee-http-server" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies ", "Pierre Krieger "] description = "HTTP server for JSON-RPC" edition = "2018" @@ -14,8 +14,8 @@ thiserror = "1" hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] } futures-channel = "0.3.14" futures-util = { version = "0.3.14", default-features = false } -jsonrpsee-types = { path = "../types", version = "=0.2.0-alpha.7" } -jsonrpsee-utils = { path = "../utils", version = "=0.2.0-alpha.7", features = ["server", "hyper_14"] } +jsonrpsee-types = { path = "../types", version = "0.2.0" } +jsonrpsee-utils = { path = "../utils", version = "0.2.0", features = ["server", "hyper_14"] } globset = "0.4" lazy_static = "1.4" log = "0.4" diff --git a/jsonrpsee/Cargo.toml b/jsonrpsee/Cargo.toml index 8fb44b5654..a131b934d6 100644 --- a/jsonrpsee/Cargo.toml +++ b/jsonrpsee/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "jsonrpsee" description = "JSON-RPC crate" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies ", "Pierre Krieger "] license = "MIT" edition = "2018" @@ -10,13 +10,13 @@ homepage = "https://github.com/paritytech/jsonrpsee" documentation = "https://docs.rs/jsonrpsee" [dependencies] -http-client = { path = "../http-client", version = "=0.2.0-alpha.7", package = "jsonrpsee-http-client", optional = true } -http-server = { path = "../http-server", version = "=0.2.0-alpha.7", package = "jsonrpsee-http-server", optional = true } -ws-client = { path = "../ws-client", version = "=0.2.0-alpha.7", package = "jsonrpsee-ws-client", optional = true } -ws-server = { path = "../ws-server", version = "=0.2.0-alpha.7", package = "jsonrpsee-ws-server", optional = true } -proc-macros = { path = "../proc-macros", version = "=0.2.0-alpha.7", package = "jsonrpsee-proc-macros", optional = true } -utils = { path = "../utils", version = "=0.2.0-alpha.7", package = "jsonrpsee-utils", optional = true } -types = { path = "../types", version = "=0.2.0-alpha.7", package = "jsonrpsee-types", optional = true } +http-client = { path = "../http-client", version = "0.2.0", package = "jsonrpsee-http-client", optional = true } +http-server = { path = "../http-server", version = "0.2.0", package = "jsonrpsee-http-server", optional = true } +ws-client = { path = "../ws-client", version = "0.2.0", package = "jsonrpsee-ws-client", optional = true } +ws-server = { path = "../ws-server", version = "0.2.0", package = "jsonrpsee-ws-server", optional = true } +proc-macros = { path = "../proc-macros", version = "0.2.0", package = "jsonrpsee-proc-macros", optional = true } +utils = { path = "../utils", version = "0.2.0", package = "jsonrpsee-utils", optional = true } +types = { path = "../types", version = "0.2.0", package = "jsonrpsee-types", optional = true } [features] client = ["http-client", "ws-client"] diff --git a/proc-macros/Cargo.toml b/proc-macros/Cargo.toml index c6bcb99fdb..93d1c71086 100644 --- a/proc-macros/Cargo.toml +++ b/proc-macros/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "jsonrpsee-proc-macros" description = "Procedueral macros for jsonrpsee" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies ", "Pierre Krieger "] license = "MIT" edition = "2018" diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index f2a3653d24..9915300453 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsonrpsee-test-utils" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies "] license = "MIT" edition = "2018" diff --git a/types/Cargo.toml b/types/Cargo.toml index 43f0184c95..31c2c32779 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsonrpsee-types" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies "] description = "Shared types for jsonrpsee" edition = "2018" diff --git a/utils/Cargo.toml b/utils/Cargo.toml index bf77c0812b..c9cf1213d2 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsonrpsee-utils" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies "] description = "Utilities for jsonrpsee" edition = "2018" @@ -12,7 +12,7 @@ futures-channel = { version = "0.3.14", default-features = false, optional = tru futures-util = { version = "0.3.14", default-features = false, optional = true } hyper13 = { package = "hyper", version = "0.13", default-features = false, features = ["stream"], optional = true } hyper14 = { package = "hyper", version = "0.14", default-features = false, features = ["stream"], optional = true } -jsonrpsee-types = { path = "../types", version = "=0.2.0-alpha.7", optional = true } +jsonrpsee-types = { path = "../types", version = "0.2.0", optional = true } log = { version = "0.4", optional = true } rustc-hash = { version = "1", optional = true } rand = { version = "0.8", optional = true } diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index 8ce542c04f..92c7256156 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsonrpsee-ws-client" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies ", "Pierre Krieger "] description = "WebSocket client for JSON-RPC" edition = "2018" @@ -23,7 +23,7 @@ tokioV02-util = { package="tokio-util", version = "0.3", features = ["compat"], async-trait = "0.1" fnv = "1" futures = { version = "0.3.14", default-features = false, features = ["std"] } -jsonrpsee-types = { path = "../types", version = "=0.2.0-alpha.7" } +jsonrpsee-types = { path = "../types", version = "0.2.0" } log = "0.4" serde = "1" serde_json = "1" diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml index 0e4b245dce..20ad2a1619 100644 --- a/ws-server/Cargo.toml +++ b/ws-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "jsonrpsee-ws-server" -version = "0.2.0-alpha.7" +version = "0.2.0" authors = ["Parity Technologies ", "Pierre Krieger "] description = "WebSocket server for JSON-RPC" edition = "2018" @@ -13,8 +13,8 @@ documentation = "https://docs.rs/jsonrpsee-ws-server" thiserror = "1" futures-channel = "0.3.14" futures-util = { version = "0.3.14", default-features = false, features = ["io"] } -jsonrpsee-types = { path = "../types", version = "=0.2.0-alpha.7" } -jsonrpsee-utils = { path = "../utils", version = "=0.2.0-alpha.7", features = ["server"] } +jsonrpsee-types = { path = "../types", version = "0.2.0" } +jsonrpsee-utils = { path = "../utils", version = "0.2.0", features = ["server"] } log = "0.4" rustc-hash = "1.1.0" serde = { version = "1", default-features = false, features = ["derive"] } From 3f804de16c4ce07cd40b8677cb610d418afaa8b6 Mon Sep 17 00:00:00 2001 From: David Date: Fri, 4 Jun 2021 14:19:59 +0200 Subject: [PATCH 04/21] Add missing `rt` feature (#369) * Add missing `rt` feature * Use rt-multi-thread actually * More feature flag foo --- ws-client/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index 92c7256156..3a52d54bb7 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -11,12 +11,12 @@ documentation = "https://docs.rs/jsonrpsee-ws-client" [dependencies] # Tokio v1 deps -tokioV1 = { package="tokio", version = "1", features = ["net", "time"], optional=true } +tokioV1 = { package="tokio", version = "1", features = ["net", "time", "rt-multi-thread"], optional=true } tokioV1-rustls = { package="tokio-rustls", version = "0.22", optional=true } tokioV1-util = { package="tokio-util", version = "0.6", features = ["compat"], optional=true } # Tokio v0.2 deps -tokioV02 = { package="tokio", version = "0.2", features = ["net", "time"], optional=true } +tokioV02 = { package="tokio", version = "0.2", features = ["net", "time", "rt-threaded", "sync"], optional=true } tokioV02-rustls = { package="tokio-rustls", version = "0.15", optional=true } tokioV02-util = { package="tokio-util", version = "0.3", features = ["compat"], optional=true } From d5ba2bd8d3bba846da56774fdc1955e9bea2a1ed Mon Sep 17 00:00:00 2001 From: Maciej Hirsz <1096222+maciejhirsz@users.noreply.github.com> Date: Mon, 7 Jun 2021 10:51:12 +0200 Subject: [PATCH 05/21] Concat -> simple push (#370) --- utils/src/server/helpers.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/src/server/helpers.rs b/utils/src/server/helpers.rs index 178252de63..f69f19f7f7 100644 --- a/utils/src/server/helpers.rs +++ b/utils/src/server/helpers.rs @@ -44,8 +44,8 @@ pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver) -> Stri let mut buf = String::with_capacity(2048); buf.push('['); let mut buf = rx - .fold(buf, |mut acc, response| async { - acc = [acc, response].concat(); + .fold(buf, |mut acc, response| async move { + acc.push_str(&response); acc.push(','); acc }) From 67e7c3db7305321b59ec408e9b65d26f24f0a3e8 Mon Sep 17 00:00:00 2001 From: Igor Aleksanov Date: Tue, 8 Jun 2021 11:27:07 +0400 Subject: [PATCH 06/21] Fix link to ws server in README.md (#373) * Fix link to ws server in README.md * Fix http client as well --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bd561b3ac5..ee18b0a3e3 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,11 @@ Support `WebSocket` and `HTTP` transports for both client and server. The library is still under development; do not use in production. ## Sub-projects -- [jsonrpsee-http-client](./http-client) [![crates.io][ws-client-image]][ws-client-url] +- [jsonrpsee-http-client](./http-client) [![crates.io][http-client-image]][http-client-url] - [jsonrpsee-http-server](./http-server) [![crates.io][http-server-image]][http-server-url] - [jsonrpsee-proc-macros](./proc-macros) [![crates.io][proc-macros-image]][proc-macros-url] - [jsonrpsee-ws-client](./ws-client) [![crates.io][ws-client-image]][ws-client-url] -- [jsonrpsee-ws-server](./http-server) [![crates.io][http-server-image]][http-server-url] +- [jsonrpsee-ws-server](./http-server) [![crates.io][ws-server-image]][ws-server-url] [http-client-image]: https://img.shields.io/crates/v/jsonrpsee-http-client.svg [http-client-url]: https://crates.io/crates/jsonrpsee-http-client From 9a02c10a311c36185b13d2d8d71d37ee44c16c0b Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Thu, 10 Jun 2021 12:23:12 +0200 Subject: [PATCH 07/21] send text (#374) --- ws-server/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 6e6b112d61..e5f0b7efe4 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -122,7 +122,7 @@ async fn background_task( tokio::spawn(async move { while let Some(response) = rx.next().await { log::debug!("send: {}", response); - let _ = sender.send_binary_mut(response.into_bytes()).await; + let _ = sender.send_text(response).await; let _ = sender.flush().await; } }); From ca11d1264c75e7c8b6e9a4beeea03f7be1f332fe Mon Sep 17 00:00:00 2001 From: Igor Aleksanov Date: Mon, 14 Jun 2021 11:53:42 +0400 Subject: [PATCH 08/21] Async/subscription benches (#372) * Add benches for async methods * Benches for subscriptions --- benches/Cargo.toml | 3 +- benches/bench.rs | 194 ++++++++++++++++++++++++++++++++++++--------- benches/helpers.rs | 20 ++++- 3 files changed, 178 insertions(+), 39 deletions(-) diff --git a/benches/Cargo.toml b/benches/Cargo.toml index 394f91b67e..f5e6e12285 100644 --- a/benches/Cargo.toml +++ b/benches/Cargo.toml @@ -9,7 +9,8 @@ publish = false [dev-dependencies] criterion = "0.3" -futures-channel = "0.3.14" +futures-channel = "0.3.15" +futures-util = "0.3.15" jsonrpsee = { path = "../jsonrpsee", features = ["full"] } num_cpus = "1" serde_json = "1" diff --git a/benches/bench.rs b/benches/bench.rs index 1f00fdcd98..84a077f934 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,4 +1,5 @@ use criterion::*; +use helpers::{SUB_METHOD_NAME, UNSUB_METHOD_NAME}; use jsonrpsee::{ http_client::{ traits::Client, @@ -6,6 +7,7 @@ use jsonrpsee::{ v2::request::JsonRpcCallSer, HttpClientBuilder, }, + types::traits::SubscriptionClient, ws_client::WsClientBuilder, }; use std::sync::Arc; @@ -13,8 +15,44 @@ use tokio::runtime::Runtime as TokioRuntime; mod helpers; -criterion_group!(benches, http_requests, batched_http_requests, websocket_requests, jsonrpsee_types_v2); -criterion_main!(benches); +criterion_group!(types_benches, jsonrpsee_types_v2); +criterion_group!( + sync_benches, + SyncBencher::http_requests, + SyncBencher::batched_http_requests, + SyncBencher::websocket_requests +); +criterion_group!( + async_benches, + AsyncBencher::http_requests, + AsyncBencher::batched_http_requests, + AsyncBencher::websocket_requests +); +criterion_group!(subscriptions, AsyncBencher::subscriptions); +criterion_main!(types_benches, sync_benches, async_benches, subscriptions); + +#[derive(Debug, Clone, Copy)] +enum RequestType { + Sync, + Async, +} + +impl RequestType { + fn method_name(self) -> &'static str { + match self { + RequestType::Sync => crate::helpers::SYNC_METHOD_NAME, + RequestType::Async => crate::helpers::ASYNC_METHOD_NAME, + } + } + + fn group_name(self, name: &str) -> String { + let request_type_name = match self { + RequestType::Sync => "sync", + RequestType::Async => "async", + }; + format!("{}/{}", request_type_name, name) + } +} fn v2_serialize(req: JsonRpcCallSer<'_>) -> String { serde_json::to_string(&req).unwrap() @@ -39,53 +77,135 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) { }); } -pub fn http_requests(crit: &mut Criterion) { - let rt = TokioRuntime::new().unwrap(); - let url = rt.block_on(helpers::http_server()); - let client = Arc::new(HttpClientBuilder::default().build(&url).unwrap()); - run_round_trip(&rt, crit, client.clone(), "http_round_trip"); - run_concurrent_round_trip(&rt, crit, client, "http_concurrent_round_trip"); -} +trait RequestBencher { + const REQUEST_TYPE: RequestType; -pub fn batched_http_requests(crit: &mut Criterion) { - let rt = TokioRuntime::new().unwrap(); - let url = rt.block_on(helpers::http_server()); - let client = Arc::new(HttpClientBuilder::default().build(&url).unwrap()); - run_round_trip_with_batch(&rt, crit, client, "http batch requests"); + fn http_requests(crit: &mut Criterion) { + let rt = TokioRuntime::new().unwrap(); + let url = rt.block_on(helpers::http_server()); + let client = Arc::new(HttpClientBuilder::default().build(&url).unwrap()); + run_round_trip(&rt, crit, client.clone(), "http_round_trip", Self::REQUEST_TYPE); + run_concurrent_round_trip(&rt, crit, client, "http_concurrent_round_trip", Self::REQUEST_TYPE); + } + + fn batched_http_requests(crit: &mut Criterion) { + let rt = TokioRuntime::new().unwrap(); + let url = rt.block_on(helpers::http_server()); + let client = Arc::new(HttpClientBuilder::default().build(&url).unwrap()); + run_round_trip_with_batch(&rt, crit, client, "http batch requests", Self::REQUEST_TYPE); + } + + fn websocket_requests(crit: &mut Criterion) { + let rt = TokioRuntime::new().unwrap(); + let url = rt.block_on(helpers::ws_server()); + let client = + Arc::new(rt.block_on(WsClientBuilder::default().max_concurrent_requests(1024 * 1024).build(&url)).unwrap()); + run_round_trip(&rt, crit, client.clone(), "ws_round_trip", Self::REQUEST_TYPE); + run_concurrent_round_trip(&rt, crit, client, "ws_concurrent_round_trip", Self::REQUEST_TYPE); + } + + fn batched_ws_requests(crit: &mut Criterion) { + let rt = TokioRuntime::new().unwrap(); + let url = rt.block_on(helpers::ws_server()); + let client = + Arc::new(rt.block_on(WsClientBuilder::default().max_concurrent_requests(1024 * 1024).build(&url)).unwrap()); + run_round_trip_with_batch(&rt, crit, client, "ws batch requests", Self::REQUEST_TYPE); + } + + fn subscriptions(crit: &mut Criterion) { + let rt = TokioRuntime::new().unwrap(); + let url = rt.block_on(helpers::ws_server()); + let client = + Arc::new(rt.block_on(WsClientBuilder::default().max_concurrent_requests(1024 * 1024).build(&url)).unwrap()); + run_sub_round_trip(&rt, crit, client, "subscriptions"); + } } -pub fn websocket_requests(crit: &mut Criterion) { - let rt = TokioRuntime::new().unwrap(); - let url = rt.block_on(helpers::ws_server()); - let client = - Arc::new(rt.block_on(WsClientBuilder::default().max_concurrent_requests(1024 * 1024).build(&url)).unwrap()); - run_round_trip(&rt, crit, client.clone(), "ws_round_trip"); - run_concurrent_round_trip(&rt, crit, client, "ws_concurrent_round_trip"); +pub struct SyncBencher; + +impl RequestBencher for SyncBencher { + const REQUEST_TYPE: RequestType = RequestType::Sync; } +pub struct AsyncBencher; -pub fn batched_ws_requests(crit: &mut Criterion) { - let rt = TokioRuntime::new().unwrap(); - let url = rt.block_on(helpers::ws_server()); - let client = - Arc::new(rt.block_on(WsClientBuilder::default().max_concurrent_requests(1024 * 1024).build(&url)).unwrap()); - run_round_trip_with_batch(&rt, crit, client, "ws batch requests"); +impl RequestBencher for AsyncBencher { + const REQUEST_TYPE: RequestType = RequestType::Async; } -fn run_round_trip(rt: &TokioRuntime, crit: &mut Criterion, client: Arc, name: &str) { - crit.bench_function(name, |b| { +fn run_round_trip(rt: &TokioRuntime, crit: &mut Criterion, client: Arc, name: &str, request: RequestType) { + crit.bench_function(&request.group_name(name), |b| { b.iter(|| { rt.block_on(async { - black_box(client.request::("say_hello", JsonRpcParams::NoParams).await.unwrap()); + black_box(client.request::(request.method_name(), JsonRpcParams::NoParams).await.unwrap()); }) }) }); } -/// Benchmark http batch requests over batch sizes of 2, 5, 10, 50 and 100 RPCs in each batch. -fn run_round_trip_with_batch(rt: &TokioRuntime, crit: &mut Criterion, client: Arc, name: &str) { +fn run_sub_round_trip(rt: &TokioRuntime, crit: &mut Criterion, client: Arc, name: &str) { let mut group = crit.benchmark_group(name); + group.bench_function("subscribe", |b| { + b.iter_with_large_drop(|| { + rt.block_on(async { + black_box( + client + .subscribe::(SUB_METHOD_NAME, JsonRpcParams::NoParams, UNSUB_METHOD_NAME) + .await + .unwrap(), + ); + }) + }) + }); + group.bench_function("subscribe_response", |b| { + b.iter_with_setup( + || { + rt.block_on(async { + client + .subscribe::(SUB_METHOD_NAME, JsonRpcParams::NoParams, UNSUB_METHOD_NAME) + .await + .unwrap() + }) + }, + |mut sub| { + rt.block_on(async { black_box(sub.next().await.unwrap()) }); + // Note that this benchmark will include costs for measuring `drop` for subscription, + // since it's not possible to combine both `iter_with_setup` and `iter_with_large_drop`. + // To estimate pure cost of method, one should subtract the result of `unsub` bench + // from this one. + }, + ) + }); + group.bench_function("unsub", |b| { + b.iter_with_setup( + || { + rt.block_on(async { + client + .subscribe::(SUB_METHOD_NAME, JsonRpcParams::NoParams, UNSUB_METHOD_NAME) + .await + .unwrap() + }) + }, + |sub| { + // Subscription will be closed inside of the drop impl. + // Actually, it just sends a notification about object being closed, + // but it's still important to know that drop impl is not too expensive. + drop(black_box(sub)); + }, + ) + }); +} + +/// Benchmark http batch requests over batch sizes of 2, 5, 10, 50 and 100 RPCs in each batch. +fn run_round_trip_with_batch( + rt: &TokioRuntime, + crit: &mut Criterion, + client: Arc, + name: &str, + request: RequestType, +) { + let mut group = crit.benchmark_group(request.group_name(name)); for batch_size in [2, 5, 10, 50, 100usize].iter() { - let batch = vec![("say_hello", JsonRpcParams::NoParams); *batch_size]; + let batch = vec![(request.method_name(), JsonRpcParams::NoParams); *batch_size]; group.throughput(Throughput::Elements(*batch_size as u64)); group.bench_with_input(BenchmarkId::from_parameter(batch_size), batch_size, |b, _| { b.iter(|| rt.block_on(async { client.batch_request::(batch.clone()).await.unwrap() })) @@ -99,8 +219,9 @@ fn run_concurrent_round_trip( crit: &mut Criterion, client: Arc, name: &str, + request: RequestType, ) { - let mut group = crit.benchmark_group(name); + let mut group = crit.benchmark_group(request.group_name(name)); for num_concurrent_tasks in helpers::concurrent_tasks() { group.bench_function(format!("{}", num_concurrent_tasks), |b| { b.iter(|| { @@ -108,8 +229,9 @@ fn run_concurrent_round_trip( for _ in 0..num_concurrent_tasks { let client_rc = client.clone(); let task = rt.spawn(async move { - let _ = - black_box(client_rc.request::("say_hello", JsonRpcParams::NoParams).await.unwrap()); + let _ = black_box( + client_rc.request::(request.method_name(), JsonRpcParams::NoParams).await.unwrap(), + ); }); tasks.push(task); } diff --git a/benches/helpers.rs b/benches/helpers.rs index 3dc4e4c5a2..64e7da54d9 100644 --- a/benches/helpers.rs +++ b/benches/helpers.rs @@ -1,9 +1,15 @@ use futures_channel::oneshot; +use futures_util::future::FutureExt; use jsonrpsee::{ http_server::HttpServerBuilder, ws_server::{RpcModule, WsServerBuilder}, }; +pub(crate) const SYNC_METHOD_NAME: &str = "say_hello"; +pub(crate) const ASYNC_METHOD_NAME: &str = "say_hello_async"; +pub(crate) const SUB_METHOD_NAME: &str = "sub"; +pub(crate) const UNSUB_METHOD_NAME: &str = "unsub"; + /// Run jsonrpsee HTTP server for benchmarks. pub async fn http_server() -> String { let (server_started_tx, server_started_rx) = oneshot::channel(); @@ -11,7 +17,8 @@ pub async fn http_server() -> String { let mut server = HttpServerBuilder::default().max_request_body_size(u32::MAX).build("127.0.0.1:0".parse().unwrap()).unwrap(); let mut module = RpcModule::new(()); - module.register_method("say_hello", |_, _| Ok("lo")).unwrap(); + module.register_method(SYNC_METHOD_NAME, |_, _| Ok("lo")).unwrap(); + module.register_async_method(ASYNC_METHOD_NAME, |_, _| (async { Ok("lo") }).boxed()).unwrap(); server.register_module(module).unwrap(); server_started_tx.send(server.local_addr().unwrap()).unwrap(); server.start().await @@ -25,7 +32,16 @@ pub async fn ws_server() -> String { tokio::spawn(async move { let mut server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); - module.register_method("say_hello", |_, _| Ok("lo")).unwrap(); + module.register_method(SYNC_METHOD_NAME, |_, _| Ok("lo")).unwrap(); + module.register_async_method(ASYNC_METHOD_NAME, |_, _| (async { Ok("lo") }).boxed()).unwrap(); + module + .register_subscription(SUB_METHOD_NAME, UNSUB_METHOD_NAME, |_params, mut sink, _ctx| { + let x = "Hello"; + tokio::spawn(async move { sink.send(&x) }); + Ok(()) + }) + .unwrap(); + server.register_module(module).unwrap(); server_started_tx.send(server.local_addr().unwrap()).unwrap(); server.start().await From 82b161424e47704048f77c909ac2f766ff6600be Mon Sep 17 00:00:00 2001 From: Igor Aleksanov Date: Wed, 16 Jun 2021 15:27:50 +0400 Subject: [PATCH 09/21] Use criterion's async bencher (#385) * Use criterion's async bencher * Rewrite concurrent roundtrip in functional style --- benches/Cargo.toml | 2 +- benches/bench.rs | 68 ++++++++++++++++++++++------------------------ 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/benches/Cargo.toml b/benches/Cargo.toml index f5e6e12285..7c8c15c339 100644 --- a/benches/Cargo.toml +++ b/benches/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT" publish = false [dev-dependencies] -criterion = "0.3" +criterion = { version = "0.3", features = ["async_tokio", "html_reports"] } futures-channel = "0.3.15" futures-util = "0.3.15" jsonrpsee = { path = "../jsonrpsee", features = ["full"] } diff --git a/benches/bench.rs b/benches/bench.rs index 84a077f934..3c450373e7 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,4 +1,5 @@ use criterion::*; +use futures_util::future::join_all; use helpers::{SUB_METHOD_NAME, UNSUB_METHOD_NAME}; use jsonrpsee::{ http_client::{ @@ -134,10 +135,8 @@ impl RequestBencher for AsyncBencher { fn run_round_trip(rt: &TokioRuntime, crit: &mut Criterion, client: Arc, name: &str, request: RequestType) { crit.bench_function(&request.group_name(name), |b| { - b.iter(|| { - rt.block_on(async { - black_box(client.request::(request.method_name(), JsonRpcParams::NoParams).await.unwrap()); - }) + b.to_async(rt).iter(|| async { + black_box(client.request::(request.method_name(), JsonRpcParams::NoParams).await.unwrap()); }) }); } @@ -145,29 +144,28 @@ fn run_round_trip(rt: &TokioRuntime, crit: &mut Criterion, client: Arc, name: &str) { let mut group = crit.benchmark_group(name); group.bench_function("subscribe", |b| { - b.iter_with_large_drop(|| { - rt.block_on(async { - black_box( - client - .subscribe::(SUB_METHOD_NAME, JsonRpcParams::NoParams, UNSUB_METHOD_NAME) - .await - .unwrap(), - ); - }) + b.to_async(rt).iter_with_large_drop(|| async { + black_box( + client.subscribe::(SUB_METHOD_NAME, JsonRpcParams::NoParams, UNSUB_METHOD_NAME).await.unwrap(), + ); }) }); group.bench_function("subscribe_response", |b| { - b.iter_with_setup( + b.to_async(rt).iter_with_setup( || { - rt.block_on(async { - client - .subscribe::(SUB_METHOD_NAME, JsonRpcParams::NoParams, UNSUB_METHOD_NAME) - .await - .unwrap() + // We have to use `block_in_place` here since `b.to_async(rt)` automatically enters the + // runtime context and simply calling `block_on` here will cause the code to panic. + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + client + .subscribe::(SUB_METHOD_NAME, JsonRpcParams::NoParams, UNSUB_METHOD_NAME) + .await + .unwrap() + }) }) }, - |mut sub| { - rt.block_on(async { black_box(sub.next().await.unwrap()) }); + |mut sub| async move { + black_box(sub.next().await.unwrap()); // Note that this benchmark will include costs for measuring `drop` for subscription, // since it's not possible to combine both `iter_with_setup` and `iter_with_large_drop`. // To estimate pure cost of method, one should subtract the result of `unsub` bench @@ -208,7 +206,7 @@ fn run_round_trip_with_batch( let batch = vec![(request.method_name(), JsonRpcParams::NoParams); *batch_size]; group.throughput(Throughput::Elements(*batch_size as u64)); group.bench_with_input(BenchmarkId::from_parameter(batch_size), batch_size, |b, _| { - b.iter(|| rt.block_on(async { client.batch_request::(batch.clone()).await.unwrap() })) + b.to_async(rt).iter(|| async { client.batch_request::(batch.clone()).await.unwrap() }) }); } group.finish(); @@ -224,21 +222,19 @@ fn run_concurrent_round_trip( let mut group = crit.benchmark_group(request.group_name(name)); for num_concurrent_tasks in helpers::concurrent_tasks() { group.bench_function(format!("{}", num_concurrent_tasks), |b| { - b.iter(|| { - let mut tasks = Vec::new(); - for _ in 0..num_concurrent_tasks { - let client_rc = client.clone(); - let task = rt.spawn(async move { - let _ = black_box( - client_rc.request::(request.method_name(), JsonRpcParams::NoParams).await.unwrap(), - ); + b.to_async(rt).iter_with_setup( + || (0..num_concurrent_tasks).map(|_| client.clone()), + |clients| async { + let tasks = clients.map(|client| { + rt.spawn(async move { + let _ = black_box( + client.request::(request.method_name(), JsonRpcParams::NoParams).await.unwrap(), + ); + }) }); - tasks.push(task); - } - for task in tasks { - rt.block_on(task).unwrap(); - } - }) + join_all(tasks).await; + }, + ) }); } group.finish(); From 6c69a8c06e11fbb04825c722d8f090d4631ba705 Mon Sep 17 00:00:00 2001 From: Maciej Hirsz <1096222+maciejhirsz@users.noreply.github.com> Date: Fri, 18 Jun 2021 13:31:33 +0200 Subject: [PATCH 10/21] Method aliases + RpcModule: Clone (#383) * Make sync methods into Arc pointers * impl Clone for RpcModule and Methods * No need to wrap Methods in Arc anymore * Simplify generics * register_alias * fmt * grammar Co-authored-by: James Wilson * Use a separate Arc counter for tracking max_connections Co-authored-by: James Wilson --- types/src/error.rs | 3 ++ utils/src/server/rpc_module.rs | 86 ++++++++++++++++++++++++---------- ws-server/src/server.rs | 16 +++++-- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/types/src/error.rs b/types/src/error.rs index 607e193fde..f502c27b11 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -66,6 +66,9 @@ pub enum Error { /// Method was already registered. #[error("Method: {0} was already registered")] MethodAlreadyRegistered(String), + /// Method with that name has not yet been registered. + #[error("Method: {0} has not yet been registered")] + MethodNotFound(String), /// Subscribe and unsubscribe method names are the same. #[error("Cannot use the same method name for subscribe and unsubscribe, used: {0}")] SubscriptionNameConflict(String), diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 602e724fd8..bd4f1c7e44 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -18,7 +18,7 @@ use std::sync::Arc; /// implemented as a function pointer to a `Fn` function taking four arguments: /// the `id`, `params`, a channel the function uses to communicate the result (or error) /// back to `jsonrpsee`, and the connection ID (useful for the websocket transport). -pub type SyncMethod = Box Result<(), Error>>; +pub type SyncMethod = Arc Result<(), Error>>; /// Similar to [`SyncMethod`], but represents an asynchronous handler. pub type AsyncMethod = Arc< dyn Send + Sync + Fn(OwnedId, OwnedRpcParams, MethodSink, ConnectionId) -> BoxFuture<'static, Result<(), Error>>, @@ -41,6 +41,7 @@ struct SubscriptionKey { } /// Callback wrapper that can be either sync or async. +#[derive(Clone)] pub enum MethodCallback { /// Synchronous method handler. Sync(SyncMethod), @@ -81,10 +82,10 @@ impl Debug for MethodCallback { } } -/// Collection of synchronous and asynchronous methods. -#[derive(Default, Debug)] +/// Reference-counted, clone-on-write collection of synchronous and asynchronous methods. +#[derive(Default, Debug, Clone)] pub struct Methods { - callbacks: FxHashMap<&'static str, MethodCallback>, + callbacks: Arc>, } impl Methods { @@ -101,15 +102,22 @@ impl Methods { Ok(()) } + /// Helper for obtaining a mut ref to the callbacks HashMap. + fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> { + Arc::make_mut(&mut self.callbacks) + } + /// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` into `self`. /// Fails if any of the methods in `other` is present already. - pub fn merge(&mut self, other: Methods) -> Result<(), Error> { + pub fn merge(&mut self, mut other: Methods) -> Result<(), Error> { for name in other.callbacks.keys() { self.verify_method_name(name)?; } - for (name, callback) in other.callbacks { - self.callbacks.insert(name, callback); + let callbacks = self.mut_callbacks(); + + for (name, callback) in other.mut_callbacks().drain() { + callbacks.insert(name, callback); } Ok(()) @@ -137,17 +145,33 @@ impl Methods { /// Sets of JSON-RPC methods can be organized into a "module"s that are in turn registered on the server or, /// alternatively, merged with other modules to construct a cohesive API. [`RpcModule`] wraps an additional context /// argument that can be used to access data during call execution. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RpcModule { ctx: Arc, methods: Methods, } -impl RpcModule { +impl RpcModule { /// Create a new module with a given shared `Context`. pub fn new(ctx: Context) -> Self { Self { ctx: Arc::new(ctx), methods: Default::default() } } + + /// Convert a module into methods. Consumes self. + pub fn into_methods(self) -> Methods { + self.methods + } + + /// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`. + /// Fails if any of the methods in `other` is present already. + pub fn merge(&mut self, other: RpcModule) -> Result<(), Error> { + self.methods.merge(other.methods)?; + + Ok(()) + } +} + +impl RpcModule { /// Register a new synchronous RPC method, which computes the response with the given callback. pub fn register_method(&mut self, method_name: &'static str, callback: F) -> Result<(), Error> where @@ -159,9 +183,9 @@ impl RpcModule { let ctx = self.ctx.clone(); - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( method_name, - MethodCallback::Sync(Box::new(move |id, params, tx, _| { + MethodCallback::Sync(Arc::new(move |id, params, tx, _| { match callback(params, &*ctx) { Ok(res) => send_response(id, tx, res), Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), @@ -192,7 +216,7 @@ impl RpcModule { let ctx = self.ctx.clone(); - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( method_name, MethodCallback::Async(Arc::new(move |id, params, tx, _| { let ctx = ctx.clone(); @@ -265,9 +289,9 @@ impl RpcModule { { let subscribers = subscribers.clone(); - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( subscribe_method_name, - MethodCallback::Sync(Box::new(move |id, params, method_sink, conn_id| { + MethodCallback::Sync(Arc::new(move |id, params, method_sink, conn_id| { let (conn_tx, conn_rx) = oneshot::channel::<()>(); let sub_id = { const JS_NUM_MASK: SubscriptionId = !0 >> 11; @@ -293,9 +317,9 @@ impl RpcModule { } { - self.methods.callbacks.insert( + self.methods.mut_callbacks().insert( unsubscribe_method_name, - MethodCallback::Sync(Box::new(move |id, params, tx, conn_id| { + MethodCallback::Sync(Arc::new(move |id, params, tx, conn_id| { let sub_id = params.one()?; subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }); send_response(id, &tx, "Unsubscribed"); @@ -308,15 +332,16 @@ impl RpcModule { Ok(()) } - /// Convert a module into methods. Consumes self. - pub fn into_methods(self) -> Methods { - self.methods - } + /// Register an `alias` name for an `existing_method`. + pub fn register_alias(&mut self, alias: &'static str, existing_method: &'static str) -> Result<(), Error> { + self.methods.verify_method_name(alias)?; - /// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`. - /// Fails if any of the methods in `other` is present already. - pub fn merge(&mut self, other: RpcModule) -> Result<(), Error> { - self.methods.merge(other.methods)?; + let callback = match self.methods.callbacks.get(existing_method) { + Some(callback) => callback.clone(), + None => return Err(Error::MethodNotFound(existing_method.into())), + }; + + self.methods.mut_callbacks().insert(alias, callback); Ok(()) } @@ -431,4 +456,17 @@ mod tests { assert!(methods.method("hi").is_some()); assert!(methods.method("goodbye").is_some()); } + + #[test] + fn rpc_register_alias() { + let mut module = RpcModule::new(()); + + module.register_method("hello_world", |_: RpcParams, _| Ok(())).unwrap(); + module.register_alias("hello_foobar", "hello_world").unwrap(); + + let methods = module.into_methods(); + + assert!(methods.method("hello_world").is_some()); + assert!(methods.method("hello_foobar").is_some()); + } } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index e5f0b7efe4..62f82daac4 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -56,7 +56,7 @@ impl Server { /// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server. /// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`] /// is returned. Note that the [`RpcModule`] is consumed after this call. - pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { + pub fn register_module(&mut self, module: RpcModule) -> Result<(), Error> { self.methods.merge(module.into_methods())?; Ok(()) } @@ -74,7 +74,8 @@ impl Server { /// Start responding to connections requests. This will block current thread until the server is stopped. pub async fn start(self) { let mut incoming = TcpListenerStream::new(self.listener); - let methods = Arc::new(self.methods); + let methods = self.methods; + let conn_counter = Arc::new(()); let cfg = self.cfg; let mut id = 0; @@ -82,13 +83,18 @@ impl Server { if let Ok(socket) = socket { socket.set_nodelay(true).unwrap_or_else(|e| panic!("Could not set NODELAY on socket: {:?}", e)); - if Arc::strong_count(&methods) > self.cfg.max_connections as usize { + if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { log::warn!("Too many connections. Try again in a while"); continue; } let methods = methods.clone(); + let counter = conn_counter.clone(); - tokio::spawn(background_task(socket, id, methods, cfg)); + tokio::spawn(async move { + let r = background_task(socket, id, methods, cfg).await; + drop(counter); + r + }); id += 1; } @@ -99,7 +105,7 @@ impl Server { async fn background_task( socket: tokio::net::TcpStream, conn_id: ConnectionId, - methods: Arc, + methods: Methods, cfg: Settings, ) -> Result<(), Error> { // For each incoming background_task we perform a handshake. From 26b061360791c08c5ddde9663bf7de60a4bdb89c Mon Sep 17 00:00:00 2001 From: Maciej Hirsz <1096222+maciejhirsz@users.noreply.github.com> Date: Fri, 18 Jun 2021 14:13:58 +0200 Subject: [PATCH 11/21] Cross-origin protection (#375) * Initial implementation * Comments * Send a 403 on denied origin * Noodling around with `set_allowed_origins` * Error on empty list * Soketto 0.6 * fmt * Add `Builder::allow_all_origins`, clarify doc comments * Rename Cors -> AllowedOrigins, nits, no panic --- test-utils/Cargo.toml | 2 +- test-utils/src/types.rs | 6 +-- types/Cargo.toml | 2 +- types/src/error.rs | 3 ++ ws-client/Cargo.toml | 2 +- ws-server/Cargo.toml | 2 +- ws-server/src/server.rs | 92 +++++++++++++++++++++++++++++++++++++---- 7 files changed, 93 insertions(+), 16 deletions(-) diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 9915300453..aca58d3cc3 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -15,6 +15,6 @@ hyper = { version = "0.14", features = ["full"] } log = "0.4" serde = { version = "1", default-features = false, features = ["derive"] } serde_json = "1" -soketto = "0.5" +soketto = "0.6" tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.6", features = ["compat"] } diff --git a/test-utils/src/types.rs b/test-utils/src/types.rs index 419a937a41..5c5aa8c92c 100644 --- a/test-utils/src/types.rs +++ b/test-utils/src/types.rs @@ -199,12 +199,12 @@ async fn server_backend(listener: tokio::net::TcpListener, mut exit: Receiver<() async fn connection_task(socket: tokio::net::TcpStream, mode: ServerMode, mut exit: Receiver<()>) { let mut server = Server::new(socket.compat()); - let websocket_key = match server.receive_request().await { - Ok(req) => req.into_key(), + let key = match server.receive_request().await { + Ok(req) => req.key(), Err(_) => return, }; - let accept = server.send_response(&Response::Accept { key: &websocket_key, protocol: None }).await; + let accept = server.send_response(&Response::Accept { key, protocol: None }).await; if accept.is_err() { return; diff --git a/types/Cargo.toml b/types/Cargo.toml index 31c2c32779..6964760be6 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -18,5 +18,5 @@ log = { version = "0.4", default-features = false } serde = { version = "1", default-features = false, features = ["derive"] } serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"] } thiserror = "1.0" -soketto = "0.5" +soketto = "0.6" hyper = "0.14" diff --git a/types/src/error.rs b/types/src/error.rs index f502c27b11..84a699971e 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -81,6 +81,9 @@ pub enum Error { /// Configured max number of request slots exceeded. #[error("Configured max number of request slots exceeded")] MaxSlotsExceeded, + /// List passed into `set_allowed_origins` was empty + #[error("Must set at least one allowed origin")] + EmptyAllowedOrigins, /// Custom error. #[error("Custom error: {0}")] Custom(String), diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index 3a52d54bb7..04949e29bc 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -27,7 +27,7 @@ jsonrpsee-types = { path = "../types", version = "0.2.0" } log = "0.4" serde = "1" serde_json = "1" -soketto = "0.5" +soketto = "0.6" pin-project = "1" thiserror = "1" url = "2" diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml index 20ad2a1619..6601735702 100644 --- a/ws-server/Cargo.toml +++ b/ws-server/Cargo.toml @@ -19,7 +19,7 @@ log = "0.4" rustc-hash = "1.1.0" serde = { version = "1", default-features = false, features = ["derive"] } serde_json = { version = "1", features = ["raw_value"] } -soketto = "0.5" +soketto = "0.6" tokio = { version = "1", features = ["net", "rt-multi-thread", "macros"] } tokio-stream = { version = "0.1.1", features = ["net"] } tokio-util = { version = "0.6", features = ["compat"] } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 62f82daac4..21bcb8a8cd 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -76,12 +76,14 @@ impl Server { let mut incoming = TcpListenerStream::new(self.listener); let methods = self.methods; let conn_counter = Arc::new(()); - let cfg = self.cfg; let mut id = 0; while let Some(socket) = incoming.next().await { if let Ok(socket) = socket { - socket.set_nodelay(true).unwrap_or_else(|e| panic!("Could not set NODELAY on socket: {:?}", e)); + if let Err(e) = socket.set_nodelay(true) { + log::error!("Could not set NODELAY on socket: {:?}", e); + continue; + } if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { log::warn!("Too many connections. Try again in a while"); @@ -89,6 +91,7 @@ impl Server { } let methods = methods.clone(); let counter = conn_counter.clone(); + let cfg = self.cfg.clone(); tokio::spawn(async move { let r = background_task(socket, id, methods, cfg).await; @@ -111,14 +114,24 @@ async fn background_task( // For each incoming background_task we perform a handshake. let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat()))); - let websocket_key = { + let key = { let req = server.receive_request().await?; - req.into_key() + + cfg.allowed_origins.verify(req.headers().origin).map(|()| req.key()) }; - // Here we accept the client unconditionally. - let accept = Response::Accept { key: &websocket_key, protocol: None }; - server.send_response(&accept).await?; + match key { + Ok(key) => { + let accept = Response::Accept { key, protocol: None }; + server.send_response(&accept).await?; + } + Err(error) => { + let reject = Response::Reject { status_code: 403 }; + server.send_response(&reject).await?; + + return Err(error); + } + } // And we can finally transition to a websocket background_task. let (mut sender, mut receiver) = server.into_builder().finish(); @@ -185,18 +198,44 @@ async fn background_task( } } +#[derive(Debug, Clone)] +enum AllowedOrigins { + Any, + OneOf(Arc<[String]>), +} + +impl AllowedOrigins { + fn verify(&self, origin: Option<&[u8]>) -> Result<(), Error> { + if let (AllowedOrigins::OneOf(list), Some(origin)) = (self, origin) { + if !list.iter().any(|o| o.as_bytes() == origin) { + let error = format!("Origin denied: {}", String::from_utf8_lossy(origin)); + log::warn!("{}", error); + return Err(Error::Request(error)); + } + } + + Ok(()) + } +} + /// JSON-RPC Websocket server settings. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] struct Settings { /// Maximum size in bytes of a request. max_request_body_size: u32, /// Maximum number of incoming connections allowed. max_connections: u64, + /// Cross-origin policy by which to accept or deny incoming requests. + allowed_origins: AllowedOrigins, } impl Default for Settings { fn default() -> Self { - Self { max_request_body_size: TEN_MB_SIZE_BYTES, max_connections: MAX_CONNECTIONS } + Self { + max_request_body_size: TEN_MB_SIZE_BYTES, + max_connections: MAX_CONNECTIONS, + allowed_origins: AllowedOrigins::Any, + } } } @@ -219,6 +258,41 @@ impl Builder { self } + /// Set a list of allowed origins. During the handshake, the `Origin` header will be + /// checked against the list, connections without a matching origin will be denied. + /// Values should include protocol. + /// + /// ```rust + /// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default(); + /// builder.set_allowed_origins(vec!["https://example.com"]); + /// ``` + /// + /// By default allows any `Origin`. + /// + /// Will return an error if `list` is empty. Use [`allow_all_origins`](Builder::allow_all_origins) to restore the default. + pub fn set_allowed_origins(mut self, list: List) -> Result + where + List: IntoIterator, + Origin: Into, + { + let list: Arc<_> = list.into_iter().map(Into::into).collect(); + + if list.len() == 0 { + return Err(Error::EmptyAllowedOrigins); + } + + self.settings.allowed_origins = AllowedOrigins::OneOf(list); + + Ok(self) + } + + /// Restores the default behavior of allowing connections with `Origin` header + /// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins). + pub fn allow_all_origins(mut self) -> Self { + self.settings.allowed_origins = AllowedOrigins::Any; + self + } + /// Finalize the configuration of the server. Consumes the [`Builder`]. pub async fn build(self, addr: impl ToSocketAddrs) -> Result { let listener = TcpListener::bind(addr).await?; From edf5b9938eb699f1f88f77a27f8af76e44946bac Mon Sep 17 00:00:00 2001 From: Igor Aleksanov Date: Wed, 23 Jun 2021 18:05:52 +0400 Subject: [PATCH 12/21] Update roadmap link in readme (#390) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ee18b0a3e3..0908d26d21 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ The library is still under development; do not use in production. ## Roadmap -See [tracking issue for next stable release](https://github.com/paritytech/jsonrpsee/issues/251) +See [tracking issue for next stable release (0.3)](https://github.com/paritytech/jsonrpsee/issues/376) ## Users From 13ea5b2186a9cffa43f54afeb1c70a8183a365a8 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Thu, 24 Jun 2021 10:33:55 +0200 Subject: [PATCH 13/21] [jsonrpsee types]: unify a couple of types + more tests (#389) * [jsonrpsee types]: unify types + more tests * address grumbles --- http-server/src/server.rs | 8 ++++-- tests/tests/proc_macros.rs | 3 +- types/src/v2/params.rs | 39 +++++++++++++++++--------- types/src/v2/request.rs | 7 ++--- types/src/v2/response.rs | 51 +++++++++++++--------------------- utils/src/server/rpc_module.rs | 37 ++++++++++++------------ ws-client/src/client.rs | 11 +++++--- ws-client/src/helpers.rs | 10 +++---- 8 files changed, 84 insertions(+), 82 deletions(-) diff --git a/http-server/src/server.rs b/http-server/src/server.rs index 9d485db331..ed530f93c1 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -43,6 +43,7 @@ use jsonrpsee_utils::server::{ rpc_module::Methods, }; +use serde_json::value::RawValue; use socket2::{Domain, Socket, Type}; use std::{ cmp, @@ -196,7 +197,8 @@ impl Server { if let Ok(req) = serde_json::from_slice::(&body) { // NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here. methods.execute(&tx, req, 0).await; - } else if let Ok(_req) = serde_json::from_slice::(&body) { + } else if let Ok(_req) = serde_json::from_slice::>>(&body) + { return Ok::<_, HyperError>(response::ok_response("".into())); } else if let Ok(batch) = serde_json::from_slice::>(&body) { if !batch.is_empty() { @@ -207,7 +209,9 @@ impl Server { } else { send_error(Id::Null, &tx, JsonRpcErrorCode::InvalidRequest.into()); } - } else if let Ok(_batch) = serde_json::from_slice::>(&body) { + } else if let Ok(_batch) = + serde_json::from_slice::>>>(&body) + { return Ok::<_, HyperError>(response::ok_response("".into())); } else { log::error!( diff --git a/tests/tests/proc_macros.rs b/tests/tests/proc_macros.rs index d05427aa78..adf58cef30 100644 --- a/tests/tests/proc_macros.rs +++ b/tests/tests/proc_macros.rs @@ -79,6 +79,5 @@ async fn proc_macros_generic_http_client_api() { assert_eq!(Test::::say_hello(&client).await.unwrap(), "hello".to_string()); assert_eq!(Test2::::foo(&client, 99_u16).await.unwrap(), "hello".to_string()); - // TODO: https://github.com/paritytech/jsonrpsee/issues/212 - //assert!(Registrar::register_para(&client, 99, "para").await.is_ok()); + assert!(Registrar::register_para(&client, 99, "para").await.is_ok()); } diff --git a/types/src/v2/params.rs b/types/src/v2/params.rs index ae3053adf0..3b7777ea69 100644 --- a/types/src/v2/params.rs +++ b/types/src/v2/params.rs @@ -4,22 +4,12 @@ use beef::Cow; use serde::de::{self, Deserializer, Unexpected, Visitor}; use serde::ser::Serializer; use serde::{Deserialize, Serialize}; -use serde_json::{value::RawValue, Value as JsonValue}; +use serde_json::Value as JsonValue; use std::fmt; /// JSON-RPC parameter values for subscriptions. #[derive(Serialize, Deserialize, Debug)] -pub struct JsonRpcNotificationParams<'a> { - /// Subscription ID - pub subscription: u64, - /// Result. - #[serde(borrow)] - pub result: &'a RawValue, -} - -/// JSON-RPC parameter values for subscriptions with support for number and strings. -#[derive(Deserialize, Debug)] -pub struct JsonRpcNotificationParamsAlloc { +pub struct JsonRpcSubscriptionParams { /// Subscription ID pub subscription: SubscriptionId, /// Result. @@ -245,7 +235,9 @@ impl<'a> From> for OwnedId { #[cfg(test)] mod test { - use super::{Cow, Id, JsonRpcParams, JsonValue, RpcParams, SubscriptionId, TwoPointZero}; + use super::{ + Cow, Id, JsonRpcParams, JsonRpcSubscriptionParams, JsonValue, RpcParams, SubscriptionId, TwoPointZero, + }; #[test] fn id_deserialization() { @@ -335,4 +327,25 @@ mod test { assert_eq!(&serialized, initial_ser); } } + + #[test] + fn subscription_params_serialize_work() { + let ser = + serde_json::to_string(&JsonRpcSubscriptionParams { subscription: SubscriptionId::Num(12), result: "goal" }) + .unwrap(); + let exp = r#"{"subscription":12,"result":"goal"}"#; + assert_eq!(ser, exp); + } + + #[test] + fn subscription_params_deserialize_work() { + let ser = r#"{"subscription":"9","result":"offside"}"#; + assert!( + serde_json::from_str::>(ser).is_err(), + "invalid type should not be deserializable" + ); + let dsr: JsonRpcSubscriptionParams = serde_json::from_str(ser).unwrap(); + assert_eq!(dsr.subscription, SubscriptionId::Str("9".into())); + assert_eq!(dsr.result, serde_json::json!("offside")); + } } diff --git a/types/src/v2/request.rs b/types/src/v2/request.rs index 3b4555a098..dec6874770 100644 --- a/types/src/v2/request.rs +++ b/types/src/v2/request.rs @@ -31,14 +31,13 @@ pub struct JsonRpcInvalidRequest<'a> { /// JSON-RPC notification (a request object without a request ID). #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] -pub struct JsonRpcNotification<'a> { +pub struct JsonRpcNotification<'a, T> { /// JSON-RPC version. pub jsonrpc: TwoPointZero, /// Name of the method to be invoked. pub method: &'a str, /// Parameter values of the request. - #[serde(borrow)] - pub params: Option<&'a RawValue>, + pub params: T, } /// Serializable [JSON-RPC object](https://www.jsonrpc.org/specification#request-object) @@ -116,7 +115,7 @@ mod test { #[test] fn deserialize_valid_notif_works() { let ser = r#"{"jsonrpc":"2.0","method":"say_hello","params":[]}"#; - let dsr: JsonRpcNotification = serde_json::from_str(ser).unwrap(); + let dsr: JsonRpcNotification<&RawValue> = serde_json::from_str(ser).unwrap(); assert_eq!(dsr.method, "say_hello"); assert_eq!(dsr.jsonrpc, TwoPointZero); } diff --git a/types/src/v2/response.rs b/types/src/v2/response.rs index c50607b41d..937a797a23 100644 --- a/types/src/v2/response.rs +++ b/types/src/v2/response.rs @@ -1,4 +1,4 @@ -use crate::v2::params::{Id, JsonRpcNotificationParams, JsonRpcNotificationParamsAlloc, TwoPointZero}; +use crate::v2::params::{Id, TwoPointZero}; use serde::{Deserialize, Serialize}; /// JSON-RPC successful response object. @@ -14,37 +14,24 @@ pub struct JsonRpcResponse<'a, T> { pub id: Id<'a>, } -/// JSON-RPC subscription response. -#[derive(Serialize, Debug)] -pub struct JsonRpcSubscriptionResponse<'a> { - /// JSON-RPC version. - pub jsonrpc: TwoPointZero, - /// Method - pub method: &'a str, - /// Params. - pub params: JsonRpcNotificationParams<'a>, -} +#[cfg(test)] +mod tests { + use super::{Id, JsonRpcResponse, TwoPointZero}; -/// JSON-RPC subscription response. -#[derive(Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct JsonRpcSubscriptionResponseAlloc<'a, T> { - /// JSON-RPC version. - pub jsonrpc: TwoPointZero, - /// Method - pub method: &'a str, - /// Params. - pub params: JsonRpcNotificationParamsAlloc, -} + #[test] + fn serialize_call_response() { + let ser = + serde_json::to_string(&JsonRpcResponse { jsonrpc: TwoPointZero, result: "ok", id: Id::Number(1) }).unwrap(); + let exp = r#"{"jsonrpc":"2.0","result":"ok","id":1}"#; + assert_eq!(ser, exp); + } -/// JSON-RPC notification response. -#[derive(Deserialize, Serialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct JsonRpcNotifResponse<'a, T> { - /// JSON-RPC version. - pub jsonrpc: TwoPointZero, - /// Method - pub method: &'a str, - /// Params. - pub params: T, + #[test] + fn deserialize_call() { + let exp = JsonRpcResponse { jsonrpc: TwoPointZero, result: 99_u64, id: Id::Number(11) }; + let dsr: JsonRpcResponse = serde_json::from_str(r#"{"jsonrpc":"2.0", "result":99, "id":11}"#).unwrap(); + assert_eq!(dsr.jsonrpc, exp.jsonrpc); + assert_eq!(dsr.result, exp.result); + assert_eq!(dsr.id, exp.id); + } } diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index bd4f1c7e44..f001cf29c1 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -3,14 +3,15 @@ use futures_channel::{mpsc, oneshot}; use futures_util::{future::BoxFuture, FutureExt}; use jsonrpsee_types::error::{CallError, Error, SubscriptionClosedError}; use jsonrpsee_types::v2::error::{JsonRpcErrorCode, JsonRpcErrorObject, CALL_EXECUTION_FAILED_CODE}; -use jsonrpsee_types::v2::params::{Id, JsonRpcNotificationParams, OwnedId, OwnedRpcParams, RpcParams, TwoPointZero}; -use jsonrpsee_types::v2::request::JsonRpcRequest; -use jsonrpsee_types::v2::response::JsonRpcSubscriptionResponse; +use jsonrpsee_types::v2::params::{ + Id, JsonRpcSubscriptionParams, OwnedId, OwnedRpcParams, RpcParams, SubscriptionId as JsonRpcSubscriptionId, + TwoPointZero, +}; +use jsonrpsee_types::v2::request::{JsonRpcNotification, JsonRpcRequest}; use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::Serialize; -use serde_json::value::{to_raw_value, RawValue}; use std::fmt::Debug; use std::sync::Arc; @@ -367,18 +368,20 @@ pub struct SubscriptionSink { impl SubscriptionSink { /// Send message on this subscription. pub fn send(&mut self, result: &T) -> Result<(), Error> { - let result = to_raw_value(result)?; - self.send_raw_value(&result) + let msg = self.build_message(result)?; + self.inner_send(msg).map_err(Into::into) } - fn send_raw_value(&mut self, result: &RawValue) -> Result<(), Error> { - let msg = serde_json::to_string(&JsonRpcSubscriptionResponse { + fn build_message(&self, result: &T) -> Result { + serde_json::to_string(&JsonRpcNotification { jsonrpc: TwoPointZero, method: self.method, - params: JsonRpcNotificationParams { subscription: self.uniq_sub.sub_id, result: &*result }, - })?; - - self.inner_send(msg).map_err(Into::into) + params: JsonRpcSubscriptionParams { + subscription: JsonRpcSubscriptionId::Num(self.uniq_sub.sub_id), + result, + }, + }) + .map_err(Into::into) } fn inner_send(&mut self, msg: String) -> Result<(), Error> { @@ -404,14 +407,8 @@ impl SubscriptionSink { pub fn close(&mut self, close_reason: String) { self.is_connected.take(); if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) { - let result = - to_raw_value(&SubscriptionClosedError::from(close_reason)).expect("valid json infallible; qed"); - let msg = serde_json::to_string(&JsonRpcSubscriptionResponse { - jsonrpc: TwoPointZero, - method: self.method, - params: JsonRpcNotificationParams { subscription: self.uniq_sub.sub_id, result: &*result }, - }) - .expect("valid json infallible; qed"); + let msg = + self.build_message(&SubscriptionClosedError::from(close_reason)).expect("valid json infallible; qed"); let _ = sink.unbounded_send(msg); } } diff --git a/ws-client/src/client.rs b/ws-client/src/client.rs index 6073f3172d..7e61cb79f0 100644 --- a/ws-client/src/client.rs +++ b/ws-client/src/client.rs @@ -29,8 +29,8 @@ use crate::traits::{Client, SubscriptionClient}; use crate::transport::{Receiver as WsReceiver, Sender as WsSender, Target, WsTransportClientBuilder}; use crate::v2::error::JsonRpcError; use crate::v2::params::{Id, JsonRpcParams}; -use crate::v2::request::{JsonRpcCallSer, JsonRpcNotificationSer}; -use crate::v2::response::{JsonRpcNotifResponse, JsonRpcResponse, JsonRpcSubscriptionResponseAlloc}; +use crate::v2::request::{JsonRpcCallSer, JsonRpcNotification, JsonRpcNotificationSer}; +use crate::v2::response::JsonRpcResponse; use crate::TEN_MB_SIZE_BYTES; use crate::{ helpers::{ @@ -51,6 +51,7 @@ use futures::{ sink::SinkExt, }; +use jsonrpsee_types::v2::params::JsonRpcSubscriptionParams; use jsonrpsee_types::SubscriptionKind; use serde::de::DeserializeOwned; use std::{ @@ -621,14 +622,16 @@ async fn background_task( } } // Subscription response. - else if let Ok(notif) = serde_json::from_slice::>(&raw) { + else if let Ok(notif) = + serde_json::from_slice::>>(&raw) + { log::debug!("[backend]: recv subscription {:?}", notif); if let Err(Some(unsub)) = process_subscription_response(&mut manager, notif) { let _ = stop_subscription(&mut sender, &mut manager, unsub).await; } } // Incoming Notification - else if let Ok(notif) = serde_json::from_slice::>(&raw) { + else if let Ok(notif) = serde_json::from_slice::>(&raw) { log::debug!("[backend]: recv notification {:?}", notif); let _ = process_notification(&mut manager, notif); } diff --git a/ws-client/src/helpers.rs b/ws-client/src/helpers.rs index c88e4e75d9..bd8957c740 100644 --- a/ws-client/src/helpers.rs +++ b/ws-client/src/helpers.rs @@ -1,9 +1,9 @@ use crate::manager::{RequestManager, RequestStatus}; use crate::transport::Sender as WsSender; use futures::channel::mpsc; -use jsonrpsee_types::v2::params::{Id, JsonRpcParams, SubscriptionId}; -use jsonrpsee_types::v2::request::JsonRpcCallSer; -use jsonrpsee_types::v2::response::{JsonRpcNotifResponse, JsonRpcResponse, JsonRpcSubscriptionResponseAlloc}; +use jsonrpsee_types::v2::params::{Id, JsonRpcParams, JsonRpcSubscriptionParams, SubscriptionId}; +use jsonrpsee_types::v2::request::{JsonRpcCallSer, JsonRpcNotification}; +use jsonrpsee_types::v2::response::JsonRpcResponse; use jsonrpsee_types::{v2::error::JsonRpcError, Error, RequestMessage}; use serde_json::Value as JsonValue; @@ -46,7 +46,7 @@ pub fn process_batch_response(manager: &mut RequestManager, rps: Vec, + notif: JsonRpcNotification>, ) -> Result<(), Option> { let sub_id = notif.params.subscription; let request_id = match manager.get_request_id_by_subscription_id(&sub_id) { @@ -75,7 +75,7 @@ pub fn process_subscription_response( /// /// Returns Ok() if the response was successfully handled /// Returns Err() if there was no handler for the method -pub fn process_notification(manager: &mut RequestManager, notif: JsonRpcNotifResponse) -> Result<(), Error> { +pub fn process_notification(manager: &mut RequestManager, notif: JsonRpcNotification) -> Result<(), Error> { match manager.as_notification_handler_mut(notif.method.to_owned()) { Some(send_back_sink) => match send_back_sink.try_send(notif.params) { Ok(()) => Ok(()), From c93b1e7a4640ae60a0e89b2e7344ea37e631e080 Mon Sep 17 00:00:00 2001 From: Igor Aleksanov Date: Thu, 24 Jun 2021 14:00:28 +0400 Subject: [PATCH 14/21] Add a way to stop servers (#386) * Add a way to stop HTTP server * Add a way to stop WS server * Apply suggestions from code review Co-authored-by: David * Ensure the concrete type of error in stop test * Resolve merge artifacts * Add public re-exports of stop handle Co-authored-by: David --- http-server/src/lib.rs | 2 +- http-server/src/server.rs | 46 ++++++++++++++- http-server/src/tests.rs | 32 ++++++++++- types/src/error.rs | 3 + ws-server/src/lib.rs | 2 +- ws-server/src/server.rs | 115 +++++++++++++++++++++++++++++--------- ws-server/src/tests.rs | 35 ++++++++++-- 7 files changed, 196 insertions(+), 39 deletions(-) diff --git a/http-server/src/lib.rs b/http-server/src/lib.rs index 155451d7e1..a19cd64598 100644 --- a/http-server/src/lib.rs +++ b/http-server/src/lib.rs @@ -42,7 +42,7 @@ pub use access_control::{ }; pub use jsonrpsee_types::{Error, TEN_MB_SIZE_BYTES}; pub use jsonrpsee_utils::server::rpc_module::RpcModule; -pub use server::{Builder as HttpServerBuilder, Server as HttpServer}; +pub use server::{Builder as HttpServerBuilder, Server as HttpServer, StopHandle as HttpStopHandle}; #[cfg(test)] mod tests; diff --git a/http-server/src/server.rs b/http-server/src/server.rs index ed530f93c1..a7ad679bcb 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -26,7 +26,7 @@ use crate::{response, AccessControl, TEN_MB_SIZE_BYTES}; use futures_channel::mpsc; -use futures_util::stream::StreamExt; +use futures_util::{lock::Mutex, stream::StreamExt, SinkExt}; use hyper::{ server::{conn::AddrIncoming, Builder as HyperBuilder}, service::{make_service_fn, service_fn}, @@ -96,12 +96,16 @@ impl Builder { let local_addr = listener.local_addr().ok(); let listener = hyper::Server::from_tcp(listener)?; + + let stop_pair = mpsc::channel(1); Ok(Server { listener, local_addr, methods: Methods::default(), access_control: self.access_control, max_request_body_size: self.max_request_body_size, + stop_pair, + stop_handle: Arc::new(Mutex::new(())), }) } } @@ -112,6 +116,25 @@ impl Default for Builder { } } +/// Handle used to stop the running server. +#[derive(Debug, Clone)] +pub struct StopHandle { + stop_sender: mpsc::Sender<()>, + stop_handle: Arc>, +} + +impl StopHandle { + /// Requests server to stop. Returns an error if server was already stopped. + pub async fn stop(&mut self) -> Result<(), Error> { + self.stop_sender.send(()).await.map_err(|_| Error::AlreadyStopped) + } + + /// Blocks indefinitely until the server is stopped. + pub async fn wait_for_stop(&self) { + self.stop_handle.lock().await; + } +} + /// An HTTP JSON RPC server. #[derive(Debug)] pub struct Server { @@ -125,6 +148,10 @@ pub struct Server { max_request_body_size: u32, /// Access control access_control: AccessControl, + /// Pair of channels to stop the server. + stop_pair: (mpsc::Sender<()>, mpsc::Receiver<()>), + /// Stop handle that indicates whether server has been stopped. + stop_handle: Arc>, } impl Server { @@ -146,11 +173,21 @@ impl Server { self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into())) } + /// Returns the handle to stop the running server. + pub fn stop_handle(&self) -> StopHandle { + StopHandle { stop_sender: self.stop_pair.0.clone(), stop_handle: self.stop_handle.clone() } + } + /// Start the server. pub async fn start(self) -> Result<(), Error> { + // Lock the stop mutex so existing stop handles can wait for server to stop. + // It will be unlocked once this function returns. + let _stop_handle = self.stop_handle.lock().await; + let methods = Arc::new(self.methods); let max_request_body_size = self.max_request_body_size; let access_control = self.access_control; + let mut stop_receiver = self.stop_pair.1; let make_service = make_service_fn(move |_| { let methods = methods.clone(); @@ -240,7 +277,12 @@ impl Server { }); let server = self.listener.serve(make_service); - server.await.map_err(Into::into) + server + .with_graceful_shutdown(async move { + stop_receiver.next().await; + }) + .await + .map_err(Into::into) } } diff --git a/http-server/src/tests.rs b/http-server/src/tests.rs index dd3d520f13..14532e04b1 100644 --- a/http-server/src/tests.rs +++ b/http-server/src/tests.rs @@ -2,15 +2,20 @@ use std::net::SocketAddr; -use crate::{HttpServerBuilder, RpcModule}; +use crate::{server::StopHandle, HttpServerBuilder, RpcModule}; use futures_util::FutureExt; use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext}; use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_types::error::{CallError, Error}; use serde_json::Value as JsonValue; +use tokio::task::JoinHandle; async fn server() -> SocketAddr { + server_with_handles().await.0 +} + +async fn server_with_handles() -> (SocketAddr, JoinHandle>, StopHandle) { let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); let ctx = TestContext; let mut module = RpcModule::new(ctx); @@ -56,8 +61,9 @@ async fn server() -> SocketAddr { .unwrap(); server.register_module(module).unwrap(); - tokio::spawn(async move { server.start().with_default_timeout().await.unwrap() }); - addr + let stop_handle = server.stop_handle(); + let join_handle = tokio::spawn(async move { server.start().with_default_timeout().await.unwrap() }); + (addr, join_handle, stop_handle) } #[tokio::test] @@ -308,3 +314,23 @@ async fn can_register_modules() { assert_eq!(err.to_string(), expected_err.to_string()); assert_eq!(server.method_names().len(), 2); } + +#[tokio::test] +async fn stop_works() { + let _ = env_logger::try_init(); + let (_addr, join_handle, mut stop_handle) = server_with_handles().with_default_timeout().await.unwrap(); + stop_handle.stop().with_default_timeout().await.unwrap().unwrap(); + stop_handle.wait_for_stop().with_default_timeout().await.unwrap(); + + // After that we should be able to wait for task handle to finish. + // First `unwrap` is timeout, second is `JoinHandle`'s one, third is the server future result. + join_handle + .with_default_timeout() + .await + .expect("Timeout") + .expect("Join error") + .expect("Server stopped with an error"); + + // After server was stopped, attempt to stop it again should result in an error. + assert!(matches!(stop_handle.stop().with_default_timeout().await.unwrap(), Err(Error::AlreadyStopped))); +} diff --git a/types/src/error.rs b/types/src/error.rs index 84a699971e..5fe57efadf 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -81,6 +81,9 @@ pub enum Error { /// Configured max number of request slots exceeded. #[error("Configured max number of request slots exceeded")] MaxSlotsExceeded, + /// Attempted to stop server that is already stopped. + #[error("Attempted to stop server that is already stopped")] + AlreadyStopped, /// List passed into `set_allowed_origins` was empty #[error("Must set at least one allowed origin")] EmptyAllowedOrigins, diff --git a/ws-server/src/lib.rs b/ws-server/src/lib.rs index 3f6bef39c8..1f8d6d45e1 100644 --- a/ws-server/src/lib.rs +++ b/ws-server/src/lib.rs @@ -39,4 +39,4 @@ mod tests; pub use jsonrpsee_types::error::Error; pub use jsonrpsee_utils::server::rpc_module::{RpcModule, SubscriptionSink}; -pub use server::{Builder as WsServerBuilder, Server as WsServer}; +pub use server::{Builder as WsServerBuilder, Server as WsServer, StopHandle as WsStopHandle}; diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 21bcb8a8cd..6a9bb99455 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -25,12 +25,18 @@ // DEALINGS IN THE SOFTWARE. use futures_channel::mpsc; -use futures_util::io::{BufReader, BufWriter}; use futures_util::stream::StreamExt; +use futures_util::{ + io::{BufReader, BufWriter}, + SinkExt, +}; use jsonrpsee_types::TEN_MB_SIZE_BYTES; use soketto::handshake::{server::Response, Server as SokettoServer}; use std::{net::SocketAddr, sync::Arc}; -use tokio::net::{TcpListener, ToSocketAddrs}; +use tokio::{ + net::{TcpListener, ToSocketAddrs}, + sync::Mutex, +}; use tokio_stream::wrappers::TcpListenerStream; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -50,6 +56,10 @@ pub struct Server { methods: Methods, listener: TcpListener, cfg: Settings, + /// Pair of channels to stop the server. + stop_pair: (mpsc::Sender<()>, mpsc::Receiver<()>), + /// Stop handle that indicates whether server has been stopped. + stop_handle: Arc>, } impl Server { @@ -71,35 +81,57 @@ impl Server { self.listener.local_addr().map_err(Into::into) } + /// Returns the handle to stop the running server. + pub fn stop_handle(&self) -> StopHandle { + StopHandle { stop_sender: self.stop_pair.0.clone(), stop_handle: self.stop_handle.clone() } + } + /// Start responding to connections requests. This will block current thread until the server is stopped. pub async fn start(self) { - let mut incoming = TcpListenerStream::new(self.listener); + // Lock the stop mutex so existing stop handles can wait for server to stop. + // It will be unlocked once this function returns. + let _stop_handle = self.stop_handle.lock().await; + + let mut incoming = TcpListenerStream::new(self.listener).fuse(); let methods = self.methods; let conn_counter = Arc::new(()); let mut id = 0; - - while let Some(socket) = incoming.next().await { - if let Ok(socket) = socket { - if let Err(e) = socket.set_nodelay(true) { - log::error!("Could not set NODELAY on socket: {:?}", e); - continue; - } - - if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { - log::warn!("Too many connections. Try again in a while"); - continue; - } - let methods = methods.clone(); - let counter = conn_counter.clone(); - let cfg = self.cfg.clone(); - - tokio::spawn(async move { - let r = background_task(socket, id, methods, cfg).await; - drop(counter); - r - }); - - id += 1; + let mut stop_receiver = self.stop_pair.1; + + loop { + futures_util::select! { + socket = incoming.next() => { + if let Some(Ok(socket)) = socket { + if let Err(e) = socket.set_nodelay(true) { + log::error!("Could not set NODELAY on socket: {:?}", e); + continue; + } + + if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { + log::warn!("Too many connections. Try again in a while"); + continue; + } + let methods = methods.clone(); + let counter = conn_counter.clone(); + let cfg = self.cfg.clone(); + + tokio::spawn(async move { + let r = background_task(socket, id, methods, cfg).await; + drop(counter); + r + }); + + id += 1; + } else { + break; + } + }, + stop = stop_receiver.next() => { + if stop.is_some() { + break; + } + }, + complete => break, } } } @@ -296,7 +328,14 @@ impl Builder { /// Finalize the configuration of the server. Consumes the [`Builder`]. pub async fn build(self, addr: impl ToSocketAddrs) -> Result { let listener = TcpListener::bind(addr).await?; - Ok(Server { listener, methods: Methods::default(), cfg: self.settings }) + let stop_pair = mpsc::channel(1); + Ok(Server { + listener, + methods: Methods::default(), + cfg: self.settings, + stop_pair, + stop_handle: Arc::new(Mutex::new(())), + }) } } @@ -305,3 +344,25 @@ impl Default for Builder { Self { settings: Settings::default() } } } + +/// Handle that is able to stop the running server. +#[derive(Debug, Clone)] +pub struct StopHandle { + stop_sender: mpsc::Sender<()>, + stop_handle: Arc>, +} + +impl StopHandle { + /// Requests server to stop. Returns an error if server was already stopped. + /// + /// Note: This method *does not* abort spawned futures, e.g. `tokio::spawn` handlers + /// for subscriptions. It only prevents server from accepting new connections. + pub async fn stop(&mut self) -> Result<(), Error> { + self.stop_sender.send(()).await.map_err(|_| Error::AlreadyStopped) + } + + /// Blocks indefinitely until the server is stopped. + pub async fn wait_for_stop(&self) { + self.stop_handle.lock().await; + } +} diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index 897261c7ce..fdf138502b 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -1,6 +1,6 @@ #![cfg(test)] -use crate::{RpcModule, WsServerBuilder}; +use crate::{server::StopHandle, RpcModule, WsServerBuilder}; use futures_util::FutureExt; use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient}; @@ -12,6 +12,7 @@ use jsonrpsee_types::{ use serde_json::Value as JsonValue; use std::fmt; use std::net::SocketAddr; +use tokio::task::JoinHandle; /// Applications can/should provide their own error. #[derive(Debug)] @@ -26,6 +27,13 @@ impl std::error::Error for MyAppError {} /// Spawns a dummy `JSONRPC v2 WebSocket` /// It has two hardcoded methods: "say_hello" and "add" async fn server() -> SocketAddr { + server_with_handles().await.0 +} + +/// Spawns a dummy `JSONRPC v2 WebSocket` +/// It has two hardcoded methods: "say_hello" and "add" +/// Returns the address together with handles for server future and server stop. +async fn server_with_handles() -> (SocketAddr, JoinHandle<()>, StopHandle) { let mut server = WsServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); let mut module = RpcModule::new(()); module @@ -64,8 +72,10 @@ async fn server() -> SocketAddr { let addr = server.local_addr().unwrap(); server.register_module(module).unwrap(); - tokio::spawn(async { server.start().await }); - addr + + let stop_handle = server.stop_handle(); + let join_handle = tokio::spawn(server.start()); + (addr, join_handle, stop_handle) } /// Run server with user provided context. @@ -114,7 +124,7 @@ async fn server_with_context() -> SocketAddr { server.register_module(rpc_module).unwrap(); let addr = server.local_addr().unwrap(); - tokio::spawn(async { server.start().await }); + tokio::spawn(server.start()); addr } @@ -305,7 +315,7 @@ async fn async_method_call_that_fails() { let req = r#"{"jsonrpc":"2.0","method":"err_async", "params":[],"id":1}"#; let response = client.send_request_text(req).await.unwrap(); - assert_eq!(response, call_execution_failed("nah".into(), Id::Num(1))); + assert_eq!(response, call_execution_failed("nah", Id::Num(1))); } #[tokio::test] @@ -442,3 +452,18 @@ async fn can_register_modules() { assert!(matches!(err, _expected_err)); assert_eq!(server.method_names().len(), 2); } + +#[tokio::test] +async fn stop_works() { + let _ = env_logger::try_init(); + let (_addr, join_handle, mut stop_handle) = server_with_handles().with_default_timeout().await.unwrap(); + stop_handle.stop().with_default_timeout().await.unwrap().unwrap(); + stop_handle.wait_for_stop().with_default_timeout().await.unwrap(); + + // After that we should be able to wait for task handle to finish. + // First `unwrap` is timeout, second is `JoinHandle`'s one. + join_handle.with_default_timeout().await.expect("Timeout").expect("Join error"); + + // After server was stopped, attempt to stop it again should result in an error. + assert!(matches!(stop_handle.stop().with_default_timeout().await.unwrap(), Err(Error::AlreadyStopped))); +} From 2ca8355a82c683b0f17746e5160fe06890ccbc3c Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 25 Jun 2021 10:56:34 +0200 Subject: [PATCH 15/21] [ci]: test each individual crate's manifest (#392) * [ci]: test each crate outside workspace We have bitten by these a few times now with that some features are leaked from the workspace which makes it compile in the workspace but not using it's own Cargo.toml. * [ci]: add tests for macos and windows * add missed `WsServer` and `HttpServer` * [ws server]: fix features * debug failure in CI * remove platform dependent assertion * fix nit; proc-macros is proc-macros * restore removed assertion * remove whitespaces --- .github/workflows/ci.yml | 132 ++++++++++++++++++++++++++++++++++++--- test-utils/src/types.rs | 14 +++-- ws-server/Cargo.toml | 2 +- ws-server/src/tests.rs | 10 +-- 4 files changed, 141 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a10dab90b0..238b614e60 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,8 +28,8 @@ jobs: override: true components: clippy, rustfmt - - name: Rust Cache - uses: Swatinem/rust-cache@v1.3.0 + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 - name: Cargo fmt uses: actions-rs/cargo@v1.0.3 @@ -59,10 +59,10 @@ jobs: toolchain: stable override: true - - name: Rust Cache - uses: Swatinem/rust-cache@v1.3.0 + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 - - name: Cargo check all targets + - name: Cargo check all targets (use Cargo.toml in workspace) uses: actions-rs/cargo@v1.0.3 with: command: check @@ -74,8 +74,68 @@ jobs: command: check args: --manifest-path http-client/Cargo.toml --no-default-features --features tokio02 + - name: Cargo check HTTP client + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path http-client/Cargo.toml + + - name: Cargo check HTTP server + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path http-server/Cargo.toml + + - name: Cargo check WS client + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path ws-client/Cargo.toml + + - name: Cargo check WS client with tokio02 + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path ws-client/Cargo.toml --no-default-features --features tokio02 + + - name: Cargo check WS server + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path ws-server/Cargo.toml + + - name: Cargo check types + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path types/Cargo.toml + + - name: Cargo check utils + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path utils/Cargo.toml + + - name: Cargo check proc macros + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path proc-macros/Cargo.toml + + - name: Cargo check test utils + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path test-utils/Cargo.toml + + - name: Cargo check examples + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --manifest-path examples/Cargo.toml + tests: - name: Run tests + name: Run tests Ubuntu runs-on: ubuntu-latest steps: - name: Checkout sources @@ -88,8 +148,64 @@ jobs: toolchain: stable override: true - - name: Rust Cache - uses: Swatinem/rust-cache@v1.3.0 + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 + + - name: Cargo build + uses: actions-rs/cargo@v1.0.3 + with: + command: build + args: --workspace + + - name: Cargo test + uses: actions-rs/cargo@v1.0.3 + with: + command: test + + tests_macos: + name: Run tests macos + runs-on: macos-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2.3.4 + + - name: Install Rust stable toolchain + uses: actions-rs/toolchain@v1.0.7 + with: + profile: minimal + toolchain: stable + override: true + + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 + + - name: Cargo build + uses: actions-rs/cargo@v1.0.3 + with: + command: build + args: --workspace + + - name: Cargo test + uses: actions-rs/cargo@v1.0.3 + with: + command: test + + tests_windows: + name: Run tests Windows + runs-on: windows-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2.3.4 + + - name: Install Rust stable toolchain + uses: actions-rs/toolchain@v1.0.7 + with: + profile: minimal + toolchain: stable + override: true + + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 - name: Cargo build uses: actions-rs/cargo@v1.0.3 diff --git a/test-utils/src/types.rs b/test-utils/src/types.rs index 5c5aa8c92c..17f4b6e3e4 100644 --- a/test-utils/src/types.rs +++ b/test-utils/src/types.rs @@ -8,8 +8,8 @@ use futures_util::{ stream::{self, StreamExt}, }; use serde::{Deserialize, Serialize}; -use soketto::handshake; -use soketto::handshake::{server::Response, Server}; +use soketto::handshake::{self, server::Response, Error as SokettoError, Server}; +use std::io; use std::net::SocketAddr; use std::time::Duration; use tokio::net::TcpStream; @@ -63,7 +63,7 @@ impl std::fmt::Debug for WebSocketTestClient { } impl WebSocketTestClient { - pub async fn new(url: SocketAddr) -> Result { + pub async fn new(url: SocketAddr) -> Result { let socket = TcpStream::connect(url).await?; let mut client = handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "test-client", "/"); match client.handshake().await { @@ -71,7 +71,13 @@ impl WebSocketTestClient { let (tx, rx) = client.into_builder().finish(); Ok(Self { tx, rx }) } - r => Err(format!("WebSocketHandshake failed: {:?}", r).into()), + Ok(handshake::ServerResponse::Redirect { .. }) => { + Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Redirection not supported in tests"))) + } + Ok(handshake::ServerResponse::Rejected { .. }) => { + Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Rejected"))) + } + Err(err) => Err(err), } } diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml index 6601735702..63fc103961 100644 --- a/ws-server/Cargo.toml +++ b/ws-server/Cargo.toml @@ -12,7 +12,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-server" [dependencies] thiserror = "1" futures-channel = "0.3.14" -futures-util = { version = "0.3.14", default-features = false, features = ["io"] } +futures-util = { version = "0.3.14", default-features = false, features = ["io", "async-await-macro"] } jsonrpsee-types = { path = "../types", version = "0.2.0" } jsonrpsee-utils = { path = "../utils", version = "0.2.0", features = ["server"] } log = "0.4" diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index fdf138502b..bf3b286bb4 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -171,10 +171,12 @@ async fn can_set_max_connections() { assert!(conn2.is_ok()); // Third connection is rejected assert!(conn3.is_err()); - let err = conn3.unwrap_err(); - assert!(err.to_string().contains("WebSocketHandshake failed")); - assert!(err.to_string().contains("Connection reset by peer")); - // Err(Io(Os { code: 54, kind: ConnectionReset, message: \"Connection reset by peer\" }))"); + + let err = match conn3 { + Err(soketto::handshake::Error::Io(err)) => err, + _ => panic!("Invalid error kind; expected std::io::Error"), + }; + assert_eq!(err.kind(), std::io::ErrorKind::ConnectionReset); // Decrement connection count drop(conn2); From 8b65edf8ce083cd0239d6c2ffb0d6dc1a4bfd042 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 29 Jun 2021 15:28:13 +0200 Subject: [PATCH 16/21] feat: customizable JSON-RPC error codes via new enum variant on `CallErrror` (#394) * feat: customizable error via RpcError trait This commit introduces a new trait for defining user customizable error codes and messages * revert trait stuff * use RawValue * fix docs * rexport to_json_raw_value --- types/src/error.rs | 17 ++++++++++++++--- types/src/lib.rs | 5 ++++- utils/src/server/rpc_module.rs | 22 +++++++++++++++------- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/types/src/error.rs b/types/src/error.rs index 5fe57efadf..7ca23178ad 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use serde_json::value::RawValue; use std::fmt; /// Convenience type for displaying errors. @@ -19,12 +20,22 @@ impl fmt::Display for Mismatch { /// Error that occurs when a call failed. #[derive(Debug, thiserror::Error)] pub enum CallError { - #[error("Invalid params in the RPC call")] /// Invalid params in the call. + #[error("Invalid params in the call")] InvalidParams, + /// The call failed (let jsonrpsee assign default error code and error message). #[error("RPC Call failed: {0}")] - /// The call failed. - Failed(#[source] Box), + Failed(Box), + /// Custom error with specific JSON-RPC error code, message and data. + #[error("RPC Call failed: code: {code}, message: {message}, data: {data:?}")] + Custom { + /// JSON-RPC error code + code: i32, + /// Short description of the error. + message: String, + /// A primitive or structured value that contains additional information about the error. + data: Option>, + }, } /// Error type. diff --git a/types/src/lib.rs b/types/src/lib.rs index 9d14f62b92..0009aee4b6 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -24,4 +24,7 @@ pub use beef::Cow; pub use client::*; pub use error::Error; pub use serde::{de::DeserializeOwned, Serialize}; -pub use serde_json::{to_value as to_json_value, value::RawValue as JsonRawValue, Value as JsonValue}; +pub use serde_json::{ + to_value as to_json_value, value::to_raw_value as to_json_raw_value, value::RawValue as JsonRawValue, + Value as JsonValue, +}; diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index f001cf29c1..46995fce72 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -69,7 +69,7 @@ impl MethodCallback { if let Err(err) = result { log::error!("execution of method call '{}' failed: {:?}, request id={:?}", req.method, err, id); - send_error(id, &tx, JsonRpcErrorCode::ServerError(-1).into()); + send_error(id, tx, JsonRpcErrorCode::ServerError(-1).into()); } } } @@ -190,14 +190,18 @@ impl RpcModule { match callback(params, &*ctx) { Ok(res) => send_response(id, tx, res), Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()), - Err(CallError::Failed(err)) => { + Err(CallError::Failed(e)) => { let err = JsonRpcErrorObject { code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), - message: &err.to_string(), + message: &e.to_string(), data: None, }; send_error(id, tx, err) } + Err(CallError::Custom { code, message, data }) => { + let err = JsonRpcErrorObject { code: code.into(), message: &message, data: data.as_deref() }; + send_error(id, tx, err) + } }; Ok(()) @@ -227,15 +231,19 @@ impl RpcModule { match callback(params, ctx).await { Ok(res) => send_response(id, &tx, res), Err(CallError::InvalidParams) => send_error(id, &tx, JsonRpcErrorCode::InvalidParams.into()), - Err(CallError::Failed(err)) => { - log::error!("Call failed with: {}", err); + Err(CallError::Failed(e)) => { let err = JsonRpcErrorObject { code: JsonRpcErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE), - message: &err.to_string(), + message: &e.to_string(), data: None, }; send_error(id, &tx, err) } + Err(CallError::Custom { code, message, data }) => { + let err = + JsonRpcErrorObject { code: code.into(), message: &message, data: data.as_deref() }; + send_error(id, &tx, err) + } }; Ok(()) }; @@ -323,7 +331,7 @@ impl RpcModule { MethodCallback::Sync(Arc::new(move |id, params, tx, conn_id| { let sub_id = params.one()?; subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }); - send_response(id, &tx, "Unsubscribed"); + send_response(id, tx, "Unsubscribed"); Ok(()) })), From 7a33bf5020896c65ce20c9c4d988e10743145f84 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 30 Jun 2021 11:29:26 +0200 Subject: [PATCH 17/21] [ws server]: terminate already established connection(s) when the server is stopped (#396) * [ws server]: terminate connection when closed. * fix tests * add test * address grumbles: return Ok when server stopped * revert log * revert outdated documentation * use wrapping add for conn id * address grumbles: replace Mutex with AtomicBool * add comment to assertion * fix nits * address grumbles: naming of variables * address grumbles: RwLock to wait for tasks This commit introduces a RwLock instead of the Mutex to the shared by the background tasks and the stop handle won't signal until all readers has been dropped. * fix nit * Update ws-server/src/server.rs * remove AtomicBool; use stop_sender instead * Update ws-server/src/server.rs * Update ws-server/src/server.rs Co-authored-by: David * correct subscription err messages Co-authored-by: David --- tests/tests/helpers.rs | 11 ++++--- tests/tests/integration_tests.rs | 38 ++++++++++++++++++---- utils/src/server/rpc_module.rs | 13 ++++---- ws-server/src/server.rs | 56 +++++++++++++++++++------------- 4 files changed, 78 insertions(+), 40 deletions(-) diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 2eaf8e1b5a..2b6991355e 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -25,11 +25,15 @@ // DEALINGS IN THE SOFTWARE. use futures_channel::oneshot; -use jsonrpsee::{http_server::HttpServerBuilder, ws_server::WsServerBuilder, RpcModule}; +use jsonrpsee::{ + http_server::HttpServerBuilder, + ws_server::{WsServerBuilder, WsStopHandle}, + RpcModule, +}; use std::net::SocketAddr; use std::time::Duration; -pub async fn websocket_server_with_subscription() -> SocketAddr { +pub async fn websocket_server_with_subscription() -> (SocketAddr, WsStopHandle) { let (server_started_tx, server_started_rx) = oneshot::channel(); std::thread::spawn(move || { @@ -84,8 +88,7 @@ pub async fn websocket_server_with_subscription() -> SocketAddr { server.register_module(module).unwrap(); rt.block_on(async move { - server_started_tx.send(server.local_addr().unwrap()).unwrap(); - + server_started_tx.send((server.local_addr().unwrap(), server.stop_handle())).unwrap(); server.start().await }); }); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 47a5663496..292adfc840 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -39,7 +39,7 @@ use std::time::Duration; #[tokio::test] async fn ws_subscription_works() { - let server_addr = websocket_server_with_subscription().await; + let (server_addr, _) = websocket_server_with_subscription().await; let server_url = format!("ws://{}", server_addr); let client = WsClientBuilder::default().build(&server_url).await.unwrap(); let mut hello_sub: Subscription = @@ -57,7 +57,7 @@ async fn ws_subscription_works() { #[tokio::test] async fn ws_subscription_with_input_works() { - let server_addr = websocket_server_with_subscription().await; + let (server_addr, _) = websocket_server_with_subscription().await; let server_url = format!("ws://{}", server_addr); let client = WsClientBuilder::default().build(&server_url).await.unwrap(); let mut add_one: Subscription = @@ -89,7 +89,7 @@ async fn http_method_call_works() { #[tokio::test] async fn ws_subscription_several_clients() { - let server_addr = websocket_server_with_subscription().await; + let (server_addr, _) = websocket_server_with_subscription().await; let server_url = format!("ws://{}", server_addr); let mut clients = Vec::with_capacity(10); @@ -105,7 +105,7 @@ async fn ws_subscription_several_clients() { #[tokio::test] async fn ws_subscription_several_clients_with_drop() { - let server_addr = websocket_server_with_subscription().await; + let (server_addr, _) = websocket_server_with_subscription().await; let server_url = format!("ws://{}", server_addr); let mut clients = Vec::with_capacity(10); @@ -153,7 +153,7 @@ async fn ws_subscription_several_clients_with_drop() { #[tokio::test] async fn ws_subscription_without_polling_doesnt_make_client_unuseable() { - let server_addr = websocket_server_with_subscription().await; + let (server_addr, _) = websocket_server_with_subscription().await; let server_url = format!("ws://{}", server_addr); let client = WsClientBuilder::default().max_notifs_per_subscription(4).build(&server_url).await.unwrap(); @@ -230,7 +230,7 @@ async fn http_with_non_ascii_url_doesnt_hang_or_panic() { #[tokio::test] async fn ws_unsubscribe_releases_request_slots() { - let server_addr = websocket_server_with_subscription().await; + let (server_addr, _) = websocket_server_with_subscription().await; let server_url = format!("ws://{}", server_addr); let client = WsClientBuilder::default().max_concurrent_requests(1).build(&server_url).await.unwrap(); @@ -244,7 +244,7 @@ async fn ws_unsubscribe_releases_request_slots() { #[tokio::test] async fn server_should_be_able_to_close_subscriptions() { - let server_addr = websocket_server_with_subscription().await; + let (server_addr, _) = websocket_server_with_subscription().await; let server_url = format!("ws://{}", server_addr); let client = WsClientBuilder::default().build(&server_url).await.unwrap(); @@ -256,3 +256,27 @@ async fn server_should_be_able_to_close_subscriptions() { assert!(matches!(res, Err(Error::SubscriptionClosed(_)))); } + +#[tokio::test] +async fn ws_close_pending_subscription_when_server_terminated() { + let (server_addr, mut handle) = websocket_server_with_subscription().await; + let server_url = format!("ws://{}", server_addr); + + let c1 = WsClientBuilder::default().build(&server_url).await.unwrap(); + + let mut sub: Subscription = + c1.subscribe("subscribe_hello", JsonRpcParams::NoParams, "unsubscribe_hello").await.unwrap(); + + assert!(matches!(sub.next().await, Ok(Some(_)))); + + handle.stop().await.unwrap(); + handle.wait_for_stop().await; + + let sub2: Result, _> = + c1.subscribe("subscribe_hello", JsonRpcParams::NoParams, "unsubscribe_hello").await; + + // no new request should be accepted. + assert!(matches!(sub2, Err(_))); + // the already established subscription should also be closed. + assert!(matches!(sub.next().await, Ok(None))); +} diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 46995fce72..5342045e3e 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -396,12 +396,12 @@ impl SubscriptionSink { let res = if let Some(conn) = self.is_connected.as_ref() { if !conn.is_canceled() { // unbounded send only fails if the receiver has been dropped. - self.inner.unbounded_send(msg).map_err(|_| subscription_closed_by_client()) + self.inner.unbounded_send(msg).map_err(|_| subscription_closed_err(self.uniq_sub.sub_id)) } else { - Err(subscription_closed_by_client()) + Err(subscription_closed_err(self.uniq_sub.sub_id)) } } else { - Err(subscription_closed_by_client()) + Err(subscription_closed_err(self.uniq_sub.sub_id)) }; if let Err(e) = &res { @@ -424,13 +424,12 @@ impl SubscriptionSink { impl Drop for SubscriptionSink { fn drop(&mut self) { - self.close(format!("Subscription: {} closed by the server", self.uniq_sub.sub_id)); + self.close(format!("Subscription: {} is closed and dropped", self.uniq_sub.sub_id)); } } -fn subscription_closed_by_client() -> Error { - const CLOSE_REASON: &str = "Subscription closed by the client"; - Error::SubscriptionClosed(CLOSE_REASON.to_owned().into()) +fn subscription_closed_err(sub_id: u64) -> Error { + Error::SubscriptionClosed(format!("Subscription {} is closed but not yet dropped", sub_id).into()) } #[cfg(test)] diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 6a9bb99455..96a69ee345 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -35,7 +35,7 @@ use soketto::handshake::{server::Response, Server as SokettoServer}; use std::{net::SocketAddr, sync::Arc}; use tokio::{ net::{TcpListener, ToSocketAddrs}, - sync::Mutex, + sync::RwLock, }; use tokio_stream::wrappers::TcpListenerStream; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -59,7 +59,7 @@ pub struct Server { /// Pair of channels to stop the server. stop_pair: (mpsc::Sender<()>, mpsc::Receiver<()>), /// Stop handle that indicates whether server has been stopped. - stop_handle: Arc>, + stop_handle: Arc>, } impl Server { @@ -88,15 +88,16 @@ impl Server { /// Start responding to connections requests. This will block current thread until the server is stopped. pub async fn start(self) { - // Lock the stop mutex so existing stop handles can wait for server to stop. - // It will be unlocked once this function returns. - let _stop_handle = self.stop_handle.lock().await; + // Acquire read access to the lock such that additional reader(s) may share this lock. + // Write access to this lock will only be possible after the server and all background tasks have stopped. + let _stop_handle = self.stop_handle.read().await; let mut incoming = TcpListenerStream::new(self.listener).fuse(); let methods = self.methods; let conn_counter = Arc::new(()); let mut id = 0; let mut stop_receiver = self.stop_pair.1; + let shutdown = self.stop_pair.0; loop { futures_util::select! { @@ -111,17 +112,19 @@ impl Server { log::warn!("Too many connections. Try again in a while"); continue; } + + let conn_counter2 = conn_counter.clone(); + let shutdown2 = shutdown.clone(); let methods = methods.clone(); - let counter = conn_counter.clone(); let cfg = self.cfg.clone(); + let stop_handle2 = self.stop_handle.clone(); tokio::spawn(async move { - let r = background_task(socket, id, methods, cfg).await; - drop(counter); - r + let _ = background_task(socket, id, methods, cfg, shutdown2, stop_handle2).await; + drop(conn_counter2); }); - id += 1; + id = id.wrapping_add(1); } else { break; } @@ -142,13 +145,15 @@ async fn background_task( conn_id: ConnectionId, methods: Methods, cfg: Settings, + shutdown: mpsc::Sender<()>, + stop_handle: Arc>, ) -> Result<(), Error> { + let _lock = stop_handle.read().await; // For each incoming background_task we perform a handshake. let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat()))); let key = { let req = server.receive_request().await?; - cfg.allowed_origins.verify(req.headers().origin).map(|()| req.key()) }; @@ -169,19 +174,27 @@ async fn background_task( let (mut sender, mut receiver) = server.into_builder().finish(); let (tx, mut rx) = mpsc::unbounded::(); + let shutdown2 = shutdown.clone(); // Send results back to the client. tokio::spawn(async move { - while let Some(response) = rx.next().await { - log::debug!("send: {}", response); - let _ = sender.send_text(response).await; - let _ = sender.flush().await; + while !shutdown2.is_closed() { + match rx.next().await { + Some(response) => { + log::debug!("send: {}", response); + let _ = sender.send_text(response).await; + let _ = sender.flush().await; + } + None => break, + }; } + // terminate connection. + let _ = sender.close().await; }); // Buffer for incoming data. let mut data = Vec::with_capacity(100); - loop { + while !shutdown.is_closed() { data.clear(); receiver.receive_data(&mut data).await?; @@ -228,6 +241,7 @@ async fn background_task( send_error(id, &tx, code.into()); } } + Ok(()) } #[derive(Debug, Clone)] @@ -334,7 +348,7 @@ impl Builder { methods: Methods::default(), cfg: self.settings, stop_pair, - stop_handle: Arc::new(Mutex::new(())), + stop_handle: Arc::new(RwLock::new(())), }) } } @@ -349,20 +363,18 @@ impl Default for Builder { #[derive(Debug, Clone)] pub struct StopHandle { stop_sender: mpsc::Sender<()>, - stop_handle: Arc>, + stop_handle: Arc>, } impl StopHandle { /// Requests server to stop. Returns an error if server was already stopped. - /// - /// Note: This method *does not* abort spawned futures, e.g. `tokio::spawn` handlers - /// for subscriptions. It only prevents server from accepting new connections. pub async fn stop(&mut self) -> Result<(), Error> { self.stop_sender.send(()).await.map_err(|_| Error::AlreadyStopped) } /// Blocks indefinitely until the server is stopped. pub async fn wait_for_stop(&self) { - self.stop_handle.lock().await; + // blocks until there are no readers left. + self.stop_handle.write().await; } } From 7496afe201bef95d8c4e4fac7d729ca92f4f89ee Mon Sep 17 00:00:00 2001 From: Maciej Hirsz <1096222+maciejhirsz@users.noreply.github.com> Date: Wed, 30 Jun 2021 17:48:50 +0200 Subject: [PATCH 18/21] Synchronization-less async connections in ws-server (#388) * WIP * More WIP * Simplify ConnDriver * Progress all connections on each poll * Make ConnDriver more opaque and less leaky * fmt * Spawn connections on tasks after handshake * WIP put connections on tasks * cargo fmt, naming clarity * Fix grumbles * Extra comment on swap_remove * Remove unwrap from the handshake * Restore the wrapping_add on connection id --- ws-server/src/server.rs | 183 ++++++++++++++++++++++++++++------------ 1 file changed, 131 insertions(+), 52 deletions(-) diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 96a69ee345..728461f5d2 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -24,7 +24,13 @@ // IN background_task WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{net::SocketAddr, sync::Arc}; + use futures_channel::mpsc; +use futures_util::future::{join_all, FutureExt}; use futures_util::stream::StreamExt; use futures_util::{ io::{BufReader, BufWriter}, @@ -32,13 +38,11 @@ use futures_util::{ }; use jsonrpsee_types::TEN_MB_SIZE_BYTES; use soketto::handshake::{server::Response, Server as SokettoServer}; -use std::{net::SocketAddr, sync::Arc}; use tokio::{ - net::{TcpListener, ToSocketAddrs}, + net::{TcpListener, TcpStream, ToSocketAddrs}, sync::RwLock, }; -use tokio_stream::wrappers::TcpListenerStream; -use tokio_util::compat::TokioAsyncReadCompatExt; +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; use jsonrpsee_types::error::Error; use jsonrpsee_types::v2::error::JsonRpcErrorCode; @@ -91,64 +95,114 @@ impl Server { // Acquire read access to the lock such that additional reader(s) may share this lock. // Write access to this lock will only be possible after the server and all background tasks have stopped. let _stop_handle = self.stop_handle.read().await; + let shutdown = self.stop_pair.0; - let mut incoming = TcpListenerStream::new(self.listener).fuse(); let methods = self.methods; - let conn_counter = Arc::new(()); let mut id = 0; - let mut stop_receiver = self.stop_pair.1; - let shutdown = self.stop_pair.0; + + let mut driver = ConnDriver::new(self.listener, self.stop_pair.1); loop { - futures_util::select! { - socket = incoming.next() => { - if let Some(Ok(socket)) = socket { - if let Err(e) = socket.set_nodelay(true) { - log::error!("Could not set NODELAY on socket: {:?}", e); - continue; - } - - if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { - log::warn!("Too many connections. Try again in a while"); - continue; - } - - let conn_counter2 = conn_counter.clone(); - let shutdown2 = shutdown.clone(); - let methods = methods.clone(); - let cfg = self.cfg.clone(); - let stop_handle2 = self.stop_handle.clone(); - - tokio::spawn(async move { - let _ = background_task(socket, id, methods, cfg, shutdown2, stop_handle2).await; - drop(conn_counter2); - }); - - id = id.wrapping_add(1); - } else { - break; + match Pin::new(&mut driver).await { + Ok((socket, _addr)) => { + if let Err(e) = socket.set_nodelay(true) { + log::error!("Could not set NODELAY on socket: {:?}", e); + continue; } - }, - stop = stop_receiver.next() => { - if stop.is_some() { - break; + + if driver.connection_count() >= self.cfg.max_connections as usize { + log::warn!("Too many connections. Try again in a while."); + continue; } - }, - complete => break, + + let methods = &methods; + let cfg = &self.cfg; + + driver.add(Box::pin(handshake(socket, id, methods, cfg, &shutdown, &self.stop_handle))); + + id = id.wrapping_add(1); + } + Err(DriverError::Io(err)) => { + log::error!("Error while awaiting a new connection: {:?}", err); + } + Err(DriverError::Shutdown) => break, } } } } -async fn background_task( +/// This is a glorified select `Future` that will attempt to drive all +/// connection futures `F` to completion on each `poll`, while also +/// handling incoming connections. +struct ConnDriver { + listener: TcpListener, + stop_receiver: mpsc::Receiver<()>, + connections: Vec, +} + +impl ConnDriver +where + F: Future + Unpin, +{ + fn new(listener: TcpListener, stop_receiver: mpsc::Receiver<()>) -> Self { + ConnDriver { listener, stop_receiver, connections: Vec::new() } + } + + fn connection_count(&self) -> usize { + self.connections.len() + } + + fn add(&mut self, conn: F) { + self.connections.push(conn); + } +} + +enum DriverError { + Shutdown, + Io(std::io::Error), +} + +impl Future for ConnDriver +where + F: Future + Unpin, +{ + type Output = Result<(TcpStream, SocketAddr), DriverError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = Pin::into_inner(self); + + let mut i = 0; + + while i < this.connections.len() { + if this.connections[i].poll_unpin(cx).is_ready() { + // Using `swap_remove` since we don't care about ordering + // but we do care about removing being `O(1)`. + // + // We don't increment `i` in this branch, since we now + // have a shorter length, and potentially a new value at + // current index + this.connections.swap_remove(i); + } else { + i += 1; + } + } + + if let Poll::Ready(Some(())) = this.stop_receiver.next().poll_unpin(cx) { + return Poll::Ready(Err(DriverError::Shutdown)); + } + + this.listener.poll_accept(cx).map_err(DriverError::Io) + } +} + +async fn handshake( socket: tokio::net::TcpStream, conn_id: ConnectionId, - methods: Methods, - cfg: Settings, - shutdown: mpsc::Sender<()>, - stop_handle: Arc>, + methods: &Methods, + cfg: &Settings, + shutdown: &mpsc::Sender<()>, + stop_handle: &Arc>, ) -> Result<(), Error> { - let _lock = stop_handle.read().await; // For each incoming background_task we perform a handshake. let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat()))); @@ -170,6 +224,31 @@ async fn background_task( } } + let join_result = tokio::spawn(background_task( + server, + conn_id, + methods.clone(), + cfg.max_request_body_size, + shutdown.clone(), + stop_handle.clone(), + )) + .await; + + match join_result { + Err(_) => Err(Error::Custom("Background task was aborted".into())), + Ok(result) => result, + } +} + +async fn background_task( + server: SokettoServer<'_, BufReader>>>, + conn_id: ConnectionId, + methods: Methods, + max_request_body_size: u32, + shutdown: mpsc::Sender<()>, + stop_handle: Arc>, +) -> Result<(), Error> { + let _lock = stop_handle.read().await; // And we can finally transition to a websocket background_task. let (mut sender, mut receiver) = server.into_builder().finish(); let (tx, mut rx) = mpsc::unbounded::(); @@ -199,8 +278,8 @@ async fn background_task( receiver.receive_data(&mut data).await?; - if data.len() > cfg.max_request_body_size as usize { - log::warn!("Request is too big ({} bytes, max is {})", data.len(), cfg.max_request_body_size); + if data.len() > max_request_body_size as usize { + log::warn!("Request is too big ({} bytes, max is {})", data.len(), max_request_body_size); send_error(Id::Null, &tx, JsonRpcErrorCode::OversizedRequest.into()); continue; } @@ -219,9 +298,9 @@ async fn background_task( // batch and read the results off of a new channel, `rx_batch`, and then send the complete batch response // back to the client over `tx`. let (tx_batch, mut rx_batch) = mpsc::unbounded::(); - for req in batch { - methods.execute(&tx_batch, req, conn_id).await; - } + + join_all(batch.into_iter().map(|req| methods.execute(&tx_batch, req, conn_id))).await; + // Closes the receiving half of a channel without dropping it. This prevents any further messages from // being sent on the channel. rx_batch.close(); From f705e325a2524024c598f4c6f015863d6e1ec0a0 Mon Sep 17 00:00:00 2001 From: Maciej Hirsz <1096222+maciejhirsz@users.noreply.github.com> Date: Thu, 1 Jul 2021 18:12:29 +0200 Subject: [PATCH 19/21] Set allowed Host header values (#399) * Set allowed Host header values * Error if allowed hosts list is empty * Grammar Co-authored-by: David Co-authored-by: David --- types/src/error.rs | 4 +-- ws-server/src/server.rs | 75 +++++++++++++++++++++++++++++++---------- 2 files changed, 60 insertions(+), 19 deletions(-) diff --git a/types/src/error.rs b/types/src/error.rs index 7ca23178ad..0dc9a6c259 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -96,8 +96,8 @@ pub enum Error { #[error("Attempted to stop server that is already stopped")] AlreadyStopped, /// List passed into `set_allowed_origins` was empty - #[error("Must set at least one allowed origin")] - EmptyAllowedOrigins, + #[error("Must set at least one allowed value for the {0} header")] + EmptyAllowList(&'static str), /// Custom error. #[error("Custom error: {0}")] Custom(String), diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 728461f5d2..020b8826c6 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -208,7 +208,10 @@ async fn handshake( let key = { let req = server.receive_request().await?; - cfg.allowed_origins.verify(req.headers().origin).map(|()| req.key()) + let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host)); + let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin); + + host_check.and(origin_check).map(|()| req.key()) }; match key { @@ -324,16 +327,16 @@ async fn background_task( } #[derive(Debug, Clone)] -enum AllowedOrigins { +enum AllowedValue { Any, - OneOf(Arc<[String]>), + OneOf(Box<[String]>), } -impl AllowedOrigins { - fn verify(&self, origin: Option<&[u8]>) -> Result<(), Error> { - if let (AllowedOrigins::OneOf(list), Some(origin)) = (self, origin) { - if !list.iter().any(|o| o.as_bytes() == origin) { - let error = format!("Origin denied: {}", String::from_utf8_lossy(origin)); +impl AllowedValue { + fn verify(&self, header: &str, value: Option<&[u8]>) -> Result<(), Error> { + if let (AllowedValue::OneOf(list), Some(value)) = (self, value) { + if !list.iter().any(|o| o.as_bytes() == value) { + let error = format!("{} denied: {}", header, String::from_utf8_lossy(value)); log::warn!("{}", error); return Err(Error::Request(error)); } @@ -350,8 +353,10 @@ struct Settings { max_request_body_size: u32, /// Maximum number of incoming connections allowed. max_connections: u64, - /// Cross-origin policy by which to accept or deny incoming requests. - allowed_origins: AllowedOrigins, + /// Policy by which to accept or deny incoming requests based on the `Origin` header. + allowed_origins: AllowedValue, + /// Policy by which to accept or deny incoming requests based on the `Host` header. + allowed_hosts: AllowedValue, } impl Default for Settings { @@ -359,7 +364,8 @@ impl Default for Settings { Self { max_request_body_size: TEN_MB_SIZE_BYTES, max_connections: MAX_CONNECTIONS, - allowed_origins: AllowedOrigins::Any, + allowed_origins: AllowedValue::Any, + allowed_hosts: AllowedValue::Any, } } } @@ -385,11 +391,11 @@ impl Builder { /// Set a list of allowed origins. During the handshake, the `Origin` header will be /// checked against the list, connections without a matching origin will be denied. - /// Values should include protocol. + /// Values should be hostnames with protocol. /// /// ```rust /// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default(); - /// builder.set_allowed_origins(vec!["https://example.com"]); + /// builder.set_allowed_origins(["https://example.com"]); /// ``` /// /// By default allows any `Origin`. @@ -400,13 +406,13 @@ impl Builder { List: IntoIterator, Origin: Into, { - let list: Arc<_> = list.into_iter().map(Into::into).collect(); + let list: Box<_> = list.into_iter().map(Into::into).collect(); if list.len() == 0 { - return Err(Error::EmptyAllowedOrigins); + return Err(Error::EmptyAllowList("Origin")); } - self.settings.allowed_origins = AllowedOrigins::OneOf(list); + self.settings.allowed_origins = AllowedValue::OneOf(list); Ok(self) } @@ -414,7 +420,42 @@ impl Builder { /// Restores the default behavior of allowing connections with `Origin` header /// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins). pub fn allow_all_origins(mut self) -> Self { - self.settings.allowed_origins = AllowedOrigins::Any; + self.settings.allowed_origins = AllowedValue::Any; + self + } + + /// Set a list of allowed hosts. During the handshake, the `Host` header will be + /// checked against the list. Connections without a matching host will be denied. + /// Values should be hostnames without protocol. + /// + /// ```rust + /// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default(); + /// builder.set_allowed_hosts(["example.com"]); + /// ``` + /// + /// By default allows any `Host`. + /// + /// Will return an error if `list` is empty. Use [`allow_all_hosts`](Builder::allow_all_hosts) to restore the default. + pub fn set_allowed_hosts(mut self, list: List) -> Result + where + List: IntoIterator, + Host: Into, + { + let list: Box<_> = list.into_iter().map(Into::into).collect(); + + if list.len() == 0 { + return Err(Error::EmptyAllowList("Host")); + } + + self.settings.allowed_hosts = AllowedValue::OneOf(list); + + Ok(self) + } + + /// Restores the default behavior of allowing connections with `Host` header + /// containing any value. This will undo any list set by [`set_allowed_hosts`](Builder::set_allowed_hosts). + pub fn allow_all_hosts(mut self) -> Self { + self.settings.allowed_hosts = AllowedValue::Any; self } From 095db9b2d5d25f4ba822ff5878150563f1643f41 Mon Sep 17 00:00:00 2001 From: Maciej Hirsz <1096222+maciejhirsz@users.noreply.github.com> Date: Thu, 1 Jul 2021 18:12:45 +0200 Subject: [PATCH 20/21] Streaming RpcParams parsing (#401) * Streaming RpcParams parsing * DRY RpcParams::one again * Fix doc comments --- types/src/v2/params.rs | 99 +++++++++++++++++++++++++++++++--- utils/src/server/rpc_module.rs | 4 +- 2 files changed, 94 insertions(+), 9 deletions(-) diff --git a/types/src/v2/params.rs b/types/src/v2/params.rs index 3b7777ea69..890a4caea9 100644 --- a/types/src/v2/params.rs +++ b/types/src/v2/params.rs @@ -68,7 +68,90 @@ impl<'a> RpcParams<'a> { Self(raw) } - /// Attempt to parse all parameters as array or map into type T + fn next_inner(&mut self) -> Option> + where + T: Deserialize<'a>, + { + let mut json = self.0?.trim_start(); + + match json.as_bytes().get(0)? { + b']' => { + self.0 = None; + + return None; + } + b'[' | b',' => json = &json[1..], + _ => return Some(Err(CallError::InvalidParams)), + } + + let mut iter = serde_json::Deserializer::from_str(json).into_iter::(); + + match iter.next()? { + Ok(value) => { + self.0 = Some(&json[iter.byte_offset()..]); + + Some(Ok(value)) + } + Err(_) => { + self.0 = None; + + Some(Err(CallError::InvalidParams)) + } + } + } + + /// Parse the next parameter to type `T` + /// + /// ``` + /// # use jsonrpsee_types::v2::params::RpcParams; + /// let mut params = RpcParams::new(Some(r#"[true, 10, "foo"]"#)); + /// + /// let a: bool = params.next().unwrap(); + /// let b: i32 = params.next().unwrap(); + /// let c: &str = params.next().unwrap(); + /// + /// assert_eq!(a, true); + /// assert_eq!(b, 10); + /// assert_eq!(c, "foo"); + /// ``` + pub fn next(&mut self) -> Result + where + T: Deserialize<'a>, + { + match self.next_inner() { + Some(result) => result, + None => Err(CallError::InvalidParams), + } + } + + /// Parse the next optional parameter to type `Option`. + /// + /// The result will be `None` for `null`, and for missing values in the supplied JSON array. + /// + /// ``` + /// # use jsonrpsee_types::v2::params::RpcParams; + /// let mut params = RpcParams::new(Some(r#"[1, 2, null]"#)); + /// + /// let params: [Option; 4] = [ + /// params.optional_next().unwrap(), + /// params.optional_next().unwrap(), + /// params.optional_next().unwrap(), + /// params.optional_next().unwrap(), + /// ];; + /// + /// assert_eq!(params, [Some(1), Some(2), None, None]); + /// ``` + pub fn optional_next(&mut self) -> Result, CallError> + where + T: Deserialize<'a>, + { + match self.next_inner::>() { + Some(result) => result, + None => Ok(None), + } + } + + /// Attempt to parse all parameters as array or map into type `T` pub fn parse(self) -> Result where T: Deserialize<'a>, @@ -77,7 +160,7 @@ impl<'a> RpcParams<'a> { serde_json::from_str(params).map_err(|_| CallError::InvalidParams) } - /// Attempt to parse only the first parameter from an array into type T + /// Attempt to parse parameters as an array of a single value of type `T`, and returns that value. pub fn one(self) -> Result where T: Deserialize<'a>, @@ -288,15 +371,17 @@ mod test { #[test] fn params_parse() { - let none = RpcParams::new(None); - assert!(none.one::().is_err()); + let mut none = RpcParams::new(None); + assert!(none.next::().is_err()); - let array_params = RpcParams::new(Some("[1, 2, 3]")); + let mut array_params = RpcParams::new(Some("[1, 2, 3]")); let arr: Result<[u64; 3], _> = array_params.parse(); assert!(arr.is_ok()); - let arr: Result<(u64, u64, u64), _> = array_params.parse(); - assert!(arr.is_ok()); + assert_eq!(array_params.next::().unwrap(), 1); + assert_eq!(array_params.next::().unwrap(), 2); + assert_eq!(array_params.next::().unwrap(), 3); + assert!(array_params.next::().is_err()); let array_one = RpcParams::new(Some("[1]")); let one: Result = array_one.one(); diff --git a/utils/src/server/rpc_module.rs b/utils/src/server/rpc_module.rs index 5342045e3e..ee09d1b4ff 100644 --- a/utils/src/server/rpc_module.rs +++ b/utils/src/server/rpc_module.rs @@ -267,8 +267,8 @@ impl RpcModule { /// use jsonrpsee_utils::server::rpc_module::RpcModule; /// /// let mut ctx = RpcModule::new(99_usize); - /// ctx.register_subscription("sub", "unsub", |params, mut sink, ctx| { - /// let x: usize = params.one()?; + /// ctx.register_subscription("sub", "unsub", |mut params, mut sink, ctx| { + /// let x: usize = params.next()?; /// std::thread::spawn(move || { /// let sum = x + (*ctx); /// sink.send(&sum) From ddb508063269d510210caa106da68b8fad20eeae Mon Sep 17 00:00:00 2001 From: Igor Aleksanov Date: Thu, 1 Jul 2021 20:37:27 +0400 Subject: [PATCH 21/21] New proc macro (#387) * Start working on the new proc macro system * Add skeleton for rendering * Improve error reporting * Main part of 'render_client' * Implement RPC client generation * Client successfully rendered * Add doc-comment generation for the API client * Check that all the methods have receiver * Start working on the server impl * Add helper method to find server crate * Fix usage of client rpc path * Decent progress on the server macro implementation * Server macro compiled successfully * Remove unneeded re-export * Insert SubscriptionSink argument to the subscription server signatures * Add basic doc-comment for the macro * no_run -> ignore * Trait with subscription compiles * Extend the example * Add integration test for client/server impl * Add trybuild setup * Set correct span for attribute parsing related errors * Add basic set of trybuild tests * Add tests for client and server generated separately * Improve proc-macro documentation * Update proc-macros/src/lib.rs Co-authored-by: Niklas Adolfsson * Fix a couple of bugs in docs * Fix rendering subscription with params Co-authored-by: Niklas Adolfsson --- jsonrpsee/src/lib.rs | 8 +- proc-macros/Cargo.toml | 7 + proc-macros/src/client_builder.rs | 170 ++++++ proc-macros/src/helpers.rs | 58 +++ proc-macros/src/lib.rs | 482 ++++++++++-------- proc-macros/src/new/attributes.rs | 51 ++ proc-macros/src/new/mod.rs | 220 ++++++++ proc-macros/src/new/render_client.rs | 134 +++++ proc-macros/src/new/render_server.rs | 263 ++++++++++ proc-macros/src/new/respan.rs | 16 + proc-macros/tests/rpc_example.rs | 15 + proc-macros/tests/ui.rs | 18 + proc-macros/tests/ui/correct/basic.rs | 80 +++ proc-macros/tests/ui/correct/only_client.rs | 17 + proc-macros/tests/ui/correct/only_server.rs | 59 +++ .../ui/incorrect/method/method_no_name.rs | 10 + .../ui/incorrect/method/method_no_name.stderr | 5 + .../method/method_unexpected_field.rs | 10 + .../method/method_unexpected_field.stderr | 5 + .../tests/ui/incorrect/rpc/rpc_assoc_items.rs | 20 + .../ui/incorrect/rpc/rpc_assoc_items.stderr | 11 + .../tests/ui/incorrect/rpc/rpc_empty.rs | 7 + .../tests/ui/incorrect/rpc/rpc_empty.stderr | 5 + .../tests/ui/incorrect/rpc/rpc_no_impls.rs | 10 + .../ui/incorrect/rpc/rpc_no_impls.stderr | 5 + .../ui/incorrect/rpc/rpc_not_qualified.rs | 9 + .../ui/incorrect/rpc/rpc_not_qualified.stderr | 5 + .../tests/ui/incorrect/sub/sub_async.rs | 10 + .../tests/ui/incorrect/sub/sub_async.stderr | 6 + .../tests/ui/incorrect/sub/sub_empty_attr.rs | 10 + .../ui/incorrect/sub/sub_empty_attr.stderr | 5 + .../tests/ui/incorrect/sub/sub_no_item.rs | 10 + .../tests/ui/incorrect/sub/sub_no_item.stderr | 5 + .../tests/ui/incorrect/sub/sub_no_name.rs | 10 + .../tests/ui/incorrect/sub/sub_no_name.stderr | 6 + .../tests/ui/incorrect/sub/sub_no_unsub.rs | 10 + .../ui/incorrect/sub/sub_no_unsub.stderr | 5 + .../tests/ui/incorrect/sub/sub_return_type.rs | 10 + .../ui/incorrect/sub/sub_return_type.stderr | 6 + .../ui/incorrect/sub/sub_unsupported_field.rs | 10 + .../sub/sub_unsupported_field.stderr | 5 + tests/tests/new_proc_macros.rs | 93 ++++ types/src/lib.rs | 10 + types/src/v2/params.rs | 7 + 44 files changed, 1709 insertions(+), 209 deletions(-) create mode 100644 proc-macros/src/client_builder.rs create mode 100644 proc-macros/src/helpers.rs create mode 100644 proc-macros/src/new/attributes.rs create mode 100644 proc-macros/src/new/mod.rs create mode 100644 proc-macros/src/new/render_client.rs create mode 100644 proc-macros/src/new/render_server.rs create mode 100644 proc-macros/src/new/respan.rs create mode 100644 proc-macros/tests/rpc_example.rs create mode 100644 proc-macros/tests/ui.rs create mode 100644 proc-macros/tests/ui/correct/basic.rs create mode 100644 proc-macros/tests/ui/correct/only_client.rs create mode 100644 proc-macros/tests/ui/correct/only_server.rs create mode 100644 proc-macros/tests/ui/incorrect/method/method_no_name.rs create mode 100644 proc-macros/tests/ui/incorrect/method/method_no_name.stderr create mode 100644 proc-macros/tests/ui/incorrect/method/method_unexpected_field.rs create mode 100644 proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.rs create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.stderr create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_empty.rs create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_empty.stderr create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.rs create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.stderr create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.rs create mode 100644 proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.stderr create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_async.rs create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_async.stderr create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_empty_attr.stderr create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_no_item.rs create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_no_item.stderr create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_no_name.rs create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_no_name.stderr create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_no_unsub.rs create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_no_unsub.stderr create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_return_type.rs create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_return_type.stderr create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs create mode 100644 proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr create mode 100644 tests/tests/new_proc_macros.rs diff --git a/jsonrpsee/src/lib.rs b/jsonrpsee/src/lib.rs index b23fa428d1..b758b000a4 100644 --- a/jsonrpsee/src/lib.rs +++ b/jsonrpsee/src/lib.rs @@ -26,4 +26,10 @@ pub use proc_macros; /// Common types used to implement JSON RPC server and client. #[cfg(feature = "macros")] -pub use types; +pub mod types { + pub use ::types::*; + + /// Set of RPC methods that can be mounted to the server. + #[cfg(feature = "server")] + pub use utils::server::rpc_module::{RpcModule, SubscriptionSink}; +} diff --git a/proc-macros/Cargo.toml b/proc-macros/Cargo.toml index 93d1c71086..222b3c7c30 100644 --- a/proc-macros/Cargo.toml +++ b/proc-macros/Cargo.toml @@ -18,3 +18,10 @@ proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", default-features = false, features = ["extra-traits", "full"] } proc-macro-crate = "1" +bae = "0.1.6" + +[dev-dependencies] +jsonrpsee = { path = "../jsonrpsee", features = ["full"] } +trybuild = "1.0" +tokio = { version = "1", features = ["rt", "macros"] } +futures-channel = { version = "0.3.14", default-features = false } diff --git a/proc-macros/src/client_builder.rs b/proc-macros/src/client_builder.rs new file mode 100644 index 0000000000..2f7d965d90 --- /dev/null +++ b/proc-macros/src/client_builder.rs @@ -0,0 +1,170 @@ +use quote::{format_ident, quote, quote_spanned}; +use std::collections::HashSet; +use syn::spanned::Spanned as _; + +use crate::helpers::*; + +/// Generates the macro output token stream corresponding to a single API. +pub fn build_client_api(api: crate::api_def::ApiDefinition) -> Result { + let enum_name = &api.name; + let visibility = &api.visibility; + let generics = api.generics.clone(); + let mut non_used_type_params = HashSet::new(); + + let mut variants = Vec::new(); + for function in &api.definitions { + let variant_name = snake_case_to_camel_case(&function.signature.ident); + if let syn::ReturnType::Type(_, ty) = &function.signature.output { + non_used_type_params.insert(ty); + }; + + let mut params_list = Vec::new(); + + for input in function.signature.inputs.iter() { + let (ty, pat_span, param_variant_name) = match input { + syn::FnArg::Receiver(_) => { + return Err(syn::Error::new( + input.span(), + "Having `self` is not allowed in RPC queries definitions", + )); + } + syn::FnArg::Typed(syn::PatType { ty, pat, .. }) => (ty, pat.span(), param_variant_name(&pat)?), + }; + params_list.push(quote_spanned!(pat_span=> #param_variant_name: #ty)); + } + + variants.push(quote_spanned!(function.signature.ident.span()=> + #variant_name { + #(#params_list,)* + } + )); + } + + let client_impl_block = build_client_impl(&api)?; + + let mut ret_variants = Vec::new(); + for (idx, ty) in non_used_type_params.into_iter().enumerate() { + // NOTE(niklasad1): variant names are converted from `snake_case` to `CamelCase` + // It's impossible to have collisions between `_0, _1, ... _N` + // Because variant name `_0`, `__0` becomes `0` in `CamelCase` + // then `0` is not a valid identifier in Rust syntax and the error message is hard to understand. + // Perhaps document this in macro when it's ready. + let varname = format_ident!("_{}", idx); + ret_variants.push(quote_spanned!(ty.span()=> #varname (#ty))); + } + + Ok(quote_spanned!(api.name.span()=> + #visibility enum #enum_name #generics { + #(#[allow(unused)] #variants,)* #(#[allow(unused)] #ret_variants,)* + } + + #client_impl_block + )) +} + +/// Builds the impl block that allow performing outbound JSON-RPC queries. +/// +/// Generates the `impl { }` block containing functions that perform RPC client calls. +fn build_client_impl(api: &crate::api_def::ApiDefinition) -> Result { + let enum_name = &api.name; + + let (impl_generics_org, type_generics, where_clause_org) = api.generics.split_for_impl(); + let client_functions = build_client_functions(&api)?; + + Ok(quote_spanned!(api.name.span() => + impl #impl_generics_org #enum_name #type_generics #where_clause_org { + #(#client_functions)* + } + )) +} + +/// Builds the functions that allow performing outbound JSON-RPC queries. +/// +/// Generates a list of functions that perform RPC client calls. +fn build_client_functions(api: &crate::api_def::ApiDefinition) -> Result, syn::Error> { + let visibility = &api.visibility; + + let _crate = find_jsonrpsee_client_crate()?; + + let mut client_functions = Vec::new(); + for function in &api.definitions { + let f_name = &function.signature.ident; + let ret_ty = match function.signature.output { + syn::ReturnType::Default => quote!(()), + syn::ReturnType::Type(_, ref ty) => quote_spanned!(ty.span()=> #ty), + }; + let rpc_method_name = + function.attributes.method.clone().unwrap_or_else(|| function.signature.ident.to_string()); + + let mut params_list = Vec::new(); + let mut params_to_json = Vec::new(); + let mut params_to_array = Vec::new(); + let mut params_tys = Vec::new(); + + for (param_index, input) in function.signature.inputs.iter().enumerate() { + let (ty, pat_span, rpc_param_name) = match input { + syn::FnArg::Receiver(_) => { + return Err(syn::Error::new( + input.span(), + "Having `self` is not allowed in RPC queries definitions", + )); + } + syn::FnArg::Typed(syn::PatType { ty, pat, attrs, .. }) => { + (ty, pat.span(), rpc_param_name(&pat, &attrs)?) + } + }; + + let generated_param_name = + syn::Ident::new(&format!("param{}", param_index), proc_macro2::Span::call_site()); + + params_tys.push(ty); + params_list.push(quote_spanned!(pat_span=> #generated_param_name: impl Into<#ty>)); + params_to_json.push(quote_spanned!(pat_span=> + map.insert( + #rpc_param_name, + #_crate::to_json_value(#generated_param_name.into()).map_err(#_crate::Error::ParseError)? + ); + )); + params_to_array.push(quote_spanned!(pat_span=> + #_crate::to_json_value(#generated_param_name.into()).map_err(#_crate::Error::ParseError)? + )); + } + + let params_building = if params_list.is_empty() { + quote_spanned!(function.signature.span()=> #_crate::v2::params::JsonRpcParams::NoParams) + } else if function.attributes.positional_params { + quote_spanned!(function.signature.span()=> vec![#(#params_to_array),*].into()) + } else { + quote_spanned!(function.signature.span()=> + { + let mut map = std::collections::BTreeMap::new(); + #(#params_to_json)* + map.into() + } + ) + }; + + let is_notification = function.is_void_ret_type(); + let function_body = if is_notification { + quote_spanned!(function.signature.span()=> + client.notification(#rpc_method_name, #params_building).await + ) + } else { + quote_spanned!(function.signature.span()=> + client.request(#rpc_method_name, #params_building).await + ) + }; + + client_functions.push(quote_spanned!(function.signature.span()=> + #visibility async fn #f_name (client: &impl #_crate::traits::Client #(, #params_list)*) -> core::result::Result<#ret_ty, #_crate::Error> + where + #ret_ty: #_crate::DeserializeOwned + #(, #params_tys: #_crate::Serialize)* + { + #function_body + } + )); + } + + Ok(client_functions) +} diff --git a/proc-macros/src/helpers.rs b/proc-macros/src/helpers.rs new file mode 100644 index 0000000000..30d93f3469 --- /dev/null +++ b/proc-macros/src/helpers.rs @@ -0,0 +1,58 @@ +use inflector::Inflector as _; +use proc_macro2::Span; +use proc_macro_crate::{crate_name, FoundCrate}; +use quote::quote; + +/// Turns a snake case function name into an UpperCamelCase name suitable to be an enum variant. +pub(crate) fn snake_case_to_camel_case(snake_case: &syn::Ident) -> syn::Ident { + syn::Ident::new(&snake_case.to_string().to_pascal_case(), snake_case.span()) +} + +/// Determine the name of the variant in the enum based on the pattern of the function parameter. +pub(crate) fn param_variant_name(pat: &syn::Pat) -> syn::parse::Result<&syn::Ident> { + match pat { + // TODO: check other fields of the `PatIdent` + syn::Pat::Ident(ident) => Ok(&ident.ident), + _ => unimplemented!(), + } +} + +/// Determine the name of the parameter based on the pattern. +pub(crate) fn rpc_param_name(pat: &syn::Pat, _attrs: &[syn::Attribute]) -> syn::parse::Result { + // TODO: look in attributes if the user specified a param name + match pat { + // TODO: check other fields of the `PatIdent` + syn::Pat::Ident(ident) => Ok(ident.ident.to_string()), + _ => unimplemented!(), + } +} + +/// Search for client-side `jsonrpsee` in `Cargo.toml`. +pub(crate) fn find_jsonrpsee_client_crate() -> Result { + find_jsonrpsee_crate("jsonrpsee-http-client", "jsonrpsee-ws-client") +} + +/// Search for server-side `jsonrpsee` in `Cargo.toml`. +pub(crate) fn find_jsonrpsee_server_crate() -> Result { + find_jsonrpsee_crate("jsonrpsee-http-server", "jsonrpsee-ws-server") +} + +fn find_jsonrpsee_crate(http_name: &str, ws_name: &str) -> Result { + match crate_name("jsonrpsee") { + Ok(FoundCrate::Name(name)) => { + let ident = syn::Ident::new(&name, Span::call_site()); + Ok(quote!(#ident::types)) + } + Ok(FoundCrate::Itself) => panic!("Deriving RPC methods in any of the `jsonrpsee crates` is not supported"), + Err(_) => match (crate_name(http_name), crate_name(ws_name)) { + (Ok(FoundCrate::Name(name)), _) | (_, Ok(FoundCrate::Name(name))) => { + let ident = syn::Ident::new(&name, Span::call_site()); + Ok(quote!(#ident)) + } + (Ok(FoundCrate::Itself), _) | (_, Ok(FoundCrate::Itself)) => { + panic!("Deriving RPC methods in any of the `jsonrpsee crates` is not supported") + } + (_, Err(e)) => Err(syn::Error::new(Span::call_site(), &e)), + }, + } +} diff --git a/proc-macros/src/lib.rs b/proc-macros/src/lib.rs index 659ab78122..80af4b1c04 100644 --- a/proc-macros/src/lib.rs +++ b/proc-macros/src/lib.rs @@ -26,15 +26,14 @@ extern crate proc_macro; -use inflector::Inflector as _; +use new::RpcDescription; use proc_macro::TokenStream; -use proc_macro2::Span; -use proc_macro_crate::{crate_name, FoundCrate}; -use quote::{format_ident, quote, quote_spanned}; -use std::collections::HashSet; -use syn::spanned::Spanned as _; +use quote::quote; mod api_def; +mod client_builder; +mod helpers; +mod new; /// Wraps around one or more API definitions and generates an enum. /// @@ -100,7 +99,7 @@ pub fn rpc_client_api(input_token_stream: TokenStream) -> TokenStream { let mut out = Vec::with_capacity(defs.apis.len()); for api in defs.apis { - match build_client_api(api) { + match client_builder::build_client_api(api) { Ok(a) => out.push(a), Err(err) => return err.to_compile_error().into(), }; @@ -111,212 +110,279 @@ pub fn rpc_client_api(input_token_stream: TokenStream) -> TokenStream { }) } -/// Generates the macro output token stream corresponding to a single API. -fn build_client_api(api: api_def::ApiDefinition) -> Result { - let enum_name = &api.name; - let visibility = &api.visibility; - let generics = api.generics.clone(); - let mut non_used_type_params = HashSet::new(); +// New implementation starts here. - let mut variants = Vec::new(); - for function in &api.definitions { - let variant_name = snake_case_to_camel_case(&function.signature.ident); - if let syn::ReturnType::Type(_, ty) = &function.signature.output { - non_used_type_params.insert(ty); - }; - - let mut params_list = Vec::new(); - - for input in function.signature.inputs.iter() { - let (ty, pat_span, param_variant_name) = match input { - syn::FnArg::Receiver(_) => { - return Err(syn::Error::new( - input.span(), - "Having `self` is not allowed in RPC queries definitions", - )); - } - syn::FnArg::Typed(syn::PatType { ty, pat, .. }) => (ty, pat.span(), param_variant_name(&pat)?), - }; - params_list.push(quote_spanned!(pat_span=> #param_variant_name: #ty)); - } - - variants.push(quote_spanned!(function.signature.ident.span()=> - #variant_name { - #(#params_list,)* - } - )); - } - - let client_impl_block = build_client_impl(&api)?; - - let mut ret_variants = Vec::new(); - for (idx, ty) in non_used_type_params.into_iter().enumerate() { - // NOTE(niklasad1): variant names are converted from `snake_case` to `CamelCase` - // It's impossible to have collisions between `_0, _1, ... _N` - // Because variant name `_0`, `__0` becomes `0` in `CamelCase` - // then `0` is not a valid identifier in Rust syntax and the error message is hard to understand. - // Perhaps document this in macro when it's ready. - let varname = format_ident!("_{}", idx); - ret_variants.push(quote_spanned!(ty.span()=> #varname (#ty))); - } - - Ok(quote_spanned!(api.name.span()=> - #visibility enum #enum_name #generics { - #(#[allow(unused)] #variants,)* #(#[allow(unused)] #ret_variants,)* - } - - #client_impl_block - )) -} - -/// Builds the impl block that allow performing outbound JSON-RPC queries. +/// Main RPC macro. /// -/// Generates the `impl { }` block containing functions that perform RPC client calls. -fn build_client_impl(api: &api_def::ApiDefinition) -> Result { - let enum_name = &api.name; - - let (impl_generics_org, type_generics, where_clause_org) = api.generics.split_for_impl(); - let client_functions = build_client_functions(&api)?; - - Ok(quote_spanned!(api.name.span() => - impl #impl_generics_org #enum_name #type_generics #where_clause_org { - #(#client_functions)* - } - )) -} - -/// Builds the functions that allow performing outbound JSON-RPC queries. +/// ## Description /// -/// Generates a list of functions that perform RPC client calls. -fn build_client_functions(api: &api_def::ApiDefinition) -> Result, syn::Error> { - let visibility = &api.visibility; - - let _crate = find_jsonrpsee_crate()?; - - let mut client_functions = Vec::new(); - for function in &api.definitions { - let f_name = &function.signature.ident; - let ret_ty = match function.signature.output { - syn::ReturnType::Default => quote!(()), - syn::ReturnType::Type(_, ref ty) => quote_spanned!(ty.span()=> #ty), - }; - let rpc_method_name = - function.attributes.method.clone().unwrap_or_else(|| function.signature.ident.to_string()); - - let mut params_list = Vec::new(); - let mut params_to_json = Vec::new(); - let mut params_to_array = Vec::new(); - let mut params_tys = Vec::new(); - - for (param_index, input) in function.signature.inputs.iter().enumerate() { - let (ty, pat_span, rpc_param_name) = match input { - syn::FnArg::Receiver(_) => { - return Err(syn::Error::new( - input.span(), - "Having `self` is not allowed in RPC queries definitions", - )); - } - syn::FnArg::Typed(syn::PatType { ty, pat, attrs, .. }) => { - (ty, pat.span(), rpc_param_name(&pat, &attrs)?) - } - }; - - let generated_param_name = - syn::Ident::new(&format!("param{}", param_index), proc_macro2::Span::call_site()); - - params_tys.push(ty); - params_list.push(quote_spanned!(pat_span=> #generated_param_name: impl Into<#ty>)); - params_to_json.push(quote_spanned!(pat_span=> - map.insert( - #rpc_param_name, - #_crate::to_json_value(#generated_param_name.into()).map_err(#_crate::Error::ParseError)? - ); - )); - params_to_array.push(quote_spanned!(pat_span=> - #_crate::to_json_value(#generated_param_name.into()).map_err(#_crate::Error::ParseError)? - )); - } - - let params_building = if params_list.is_empty() { - quote_spanned!(function.signature.span()=> #_crate::v2::params::JsonRpcParams::NoParams) - } else if function.attributes.positional_params { - quote_spanned!(function.signature.span()=> vec![#(#params_to_array),*].into()) - } else { - quote_spanned!(function.signature.span()=> - { - let mut map = std::collections::BTreeMap::new(); - #(#params_to_json)* - map.into() - } - ) - }; - - let is_notification = function.is_void_ret_type(); - let function_body = if is_notification { - quote_spanned!(function.signature.span()=> - client.notification(#rpc_method_name, #params_building).await - ) - } else { - quote_spanned!(function.signature.span()=> - client.request(#rpc_method_name, #params_building).await - ) - }; - - client_functions.push(quote_spanned!(function.signature.span()=> - #visibility async fn #f_name (client: &impl #_crate::traits::Client #(, #params_list)*) -> core::result::Result<#ret_ty, #_crate::Error> - where - #ret_ty: #_crate::DeserializeOwned - #(, #params_tys: #_crate::Serialize)* - { - #function_body - } - )); - } - - Ok(client_functions) -} - -/// Turns a snake case function name into an UpperCamelCase name suitable to be an enum variant. -fn snake_case_to_camel_case(snake_case: &syn::Ident) -> syn::Ident { - syn::Ident::new(&snake_case.to_string().to_pascal_case(), snake_case.span()) -} - -/// Determine the name of the variant in the enum based on the pattern of the function parameter. -fn param_variant_name(pat: &syn::Pat) -> syn::parse::Result<&syn::Ident> { - match pat { - // TODO: check other fields of the `PatIdent` - syn::Pat::Ident(ident) => Ok(&ident.ident), - _ => unimplemented!(), - } -} +/// This macro is capable of generating both server and client implementations on demand. +/// Based on the attributes provided to the `rpc` macro, either one or both of implementations +/// will be generated. +/// +/// For clients, it will be an extension trait that will add all the required methods to any +/// type that implements `Client` or `SubscriptionClient` (depending on whether trait has +/// subscriptions methods or not), namely `HttpClient` and `WsClient`. +/// +/// For servers, it will generate a trait mostly equivalent to the initial one, with two main +/// differences: +/// +/// - This trait will have one additional (already implemented) method `into_rpc`, which +/// will turn any object that implements the server trait into an `RpcModule`. +/// - For subscription methods, there will be one additional argument inserted right +/// after `&self`: `subscription_sink: SubscriptionSink`. It should be used to +/// actually maintain the subscription. +/// +/// Since this macro can generate up to two traits, both server and client traits will have +/// a new name. For the `Foo` trait, server trait will be named `FooServer`, and client, +/// correspondingly, `FooClient`. +/// +/// `FooClient` in that case will only have to be imported in the context and will be ready to +/// use, while `FooServer` must be implemented for some type first. +/// +/// ## Prerequisites +/// +/// - Implementors of the server trait must be `Sync`, `Send`, `Sized` and `'static`. +/// If you want to implement this trait to some type that is not thread-safe, consider +/// using `Arc>`. +/// +/// ## Examples +/// +/// Below you can find the example of the macro usage along with the code +/// that will be generated for it. +/// +/// ```ignore +/// #[rpc(client, server, namespace = "foo")] +/// pub trait Rpc { +/// #[method(name = "foo")] +/// async fn async_method(&self, param_a: u8, param_b: String) -> u16; +/// #[method(name = "bar")] +/// fn sync_method(&self) -> String; +/// +/// #[subscription(name = "sub", unsub = "unsub", item = "String")] +/// fn sub(&self); +/// } +/// ``` +/// +/// Server code that will be generated: +/// +/// ```ignore +/// #[async_trait] +/// pub trait RpcServer { +/// // RPC methods are usual methods and can be either sync or async. +/// async fn async_method(&self, param_a: u8, param_b: String) -> u16; +/// fn sync_method(&self) -> String; +/// +/// // Note that `subscription_sink` was added automatically. +/// fn sub(&self, subscription_sink: SubscriptionSink); +/// +/// fn into_rpc(self) -> Result { +/// // Actual implementation stripped, but inside we will create +/// // a module with one method and one subscription +/// } +/// } +/// ``` +/// +/// Client code that will be generated: +/// +/// ```ignore +/// #[async_trait] +/// pub trait RpcClient: SubscriptionClient { +/// // In client implementation all the methods are (obviously) async. +/// async fn async_method(&self, param_a: u8, param_b: String) -> Result { +/// // Actual implementations are stripped, but inside a corresponding `Client` or +/// // `SubscriptionClient` method is called. +/// } +/// async fn sync_method(&self) -> Result { +/// // ... +/// } +/// +/// // Subscription method returns `Subscription` object in case of success. +/// async fn sub(&self) -> Result, Error> { +/// // ... +/// } +/// } +/// +/// impl RpcClient for T where T: SubscriptionClient {} +/// ``` +/// +/// ## Attributes +/// +/// ### `rpc` attribute +/// +/// `rpc` attribute is applied to a trait in order to turn it into an RPC implementation. +/// +/// **Arguments:** +/// +/// - `server`: generate `Server` trait for the server implementation. +/// - `client`: generate `Client` extension trait that makes RPC clients to invoke a concrete RPC +/// implementation methods conveniently. +/// - `namespace`: add a prefix to all the methods and subscriptions in this RPC. For example, with namespace +/// `foo` and method `spam`, the resulting method name will be `foo_spam`. +/// +/// **Trait requirements:** +/// +/// Trait wrapped with an `rpc` attribute **must not**: +/// +/// - have associated types or constants; +/// - have Rust methods not marked with either `method` or `subscription` attribute; +/// - be empty. +/// +/// At least one of the `server` or `client` flags must be provided, otherwise the compilation will err. +/// +/// ### `method` attribute +/// +/// `method` attribute is used to define an RPC method. +/// +/// **Arguments:** +/// +/// - `name` (mandatory): name of the RPC method. Does not have to be the same as the Rust method name. +/// +/// **Method requirements:** +/// +/// Rust method marked with `method` attribute, **may**: +/// +/// - be either `async` or not; +/// - have input parameters or not; +/// - have return value or not (in the latter case, it will be considered a notification method). +/// +/// ### `subscription` attribute +/// +/// **Arguments:** +/// +/// - `name` (mandatory): name of the RPC method. Does not have to be the same as the Rust method name. +/// - `unsub` (mandatory): name of the RPC method to unsubscribe from the subscription. Must not be the same as `name`. +/// - `item` (mandatory): type of items yielded by the subscription. Note that it must be the type, not string. +/// +/// **Method requirements:** +/// +/// Rust method marked with `subscription` attribute **must**: +/// +/// - be synchronous; +/// - not have return value. +/// +/// Rust method marked with `subscription` attribute **may**: +/// +/// - have input parameters or not. +/// +/// ## Full workflow example +/// +/// ```rust +/// //! Example of using proc macro to generate working client and server. +/// +/// use std::net::SocketAddr; +/// +/// use futures_channel::oneshot; +/// use jsonrpsee::{ws_client::*, ws_server::WsServerBuilder}; +/// +/// // RPC is moved into a separate module to clearly show names of generated entities. +/// mod rpc_impl { +/// use jsonrpsee::{proc_macros::rpc, types::async_trait, ws_server::SubscriptionSink}; +/// +/// // Generate both server and client implementations, prepend all the methods with `foo_` prefix. +/// #[rpc(client, server, namespace = "foo")] +/// pub trait Rpc { +/// #[method(name = "foo")] +/// async fn async_method(&self, param_a: u8, param_b: String) -> u16; +/// +/// #[method(name = "bar")] +/// fn sync_method(&self) -> u16; +/// +/// #[subscription(name = "sub", unsub = "unsub", item = String)] +/// fn sub(&self); +/// } +/// +/// // Structure that will implement `RpcServer` trait. +/// // In can have fields, if required, as long as it's still `Send + Sync + 'static`. +/// pub struct RpcServerImpl; +/// +/// // Note that the trait name we use is `RpcServer`, not `Rpc`! +/// #[async_trait] +/// impl RpcServer for RpcServerImpl { +/// async fn async_method(&self, _param_a: u8, _param_b: String) -> u16 { +/// 42u16 +/// } +/// +/// fn sync_method(&self) -> u16 { +/// 10u16 +/// } +/// +/// // We could've spawned a `tokio` future that yields values while our program works, +/// // but for simplicity of the example we will only send two values and then close +/// // the subscription. +/// fn sub(&self, mut sink: SubscriptionSink) { +/// sink.send(&"Response_A").unwrap(); +/// sink.send(&"Response_B").unwrap(); +/// } +/// } +/// } +/// +/// // Use generated implementations of server and client. +/// use rpc_impl::{RpcClient, RpcServer, RpcServerImpl}; +/// +/// pub async fn websocket_server() -> SocketAddr { +/// let (server_started_tx, server_started_rx) = oneshot::channel(); +/// +/// std::thread::spawn(move || { +/// let rt = tokio::runtime::Runtime::new().unwrap(); +/// let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); +/// // `into_rpc()` method was generated inside of the `RpcServer` trait under the hood. +/// server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); +/// +/// rt.block_on(async move { +/// server_started_tx.send(server.local_addr().unwrap()).unwrap(); +/// +/// server.start().await +/// }); +/// }); +/// +/// server_started_rx.await.unwrap() +/// } +/// +/// // In the main function, we will spawn the server, create a client connected to this server, +/// // and call all the available methods. +/// #[tokio::main] +/// async fn main() { +/// let server_addr = websocket_server().await; +/// let server_url = format!("ws://{}", server_addr); +/// // Note that we create the client as usual, but thanks to the `use rpc_impl::RpcClient`, +/// // the client object will have all the methods to interact with the server. +/// let client = WsClientBuilder::default().build(&server_url).await.unwrap(); +/// +/// // Invoke RPC methods. +/// assert_eq!(client.async_method(10, "a".into()).await.unwrap(), 42); +/// assert_eq!(client.sync_method().await.unwrap(), 10); +/// +/// // Subscribe and receive messages from the subscription. +/// let mut sub = client.sub().await.unwrap(); +/// let first_recv = sub.next().await.unwrap(); +/// assert_eq!(first_recv, Some("Response_A".to_string())); +/// let second_recv = sub.next().await.unwrap(); +/// assert_eq!(second_recv, Some("Response_B".to_string())); +/// } +/// ``` +#[proc_macro_attribute] +pub fn rpc(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = proc_macro2::TokenStream::from(attr); + + let rebuilt_rpc_attribute = syn::Attribute { + pound_token: syn::token::Pound::default(), + style: syn::AttrStyle::Outer, + bracket_token: syn::token::Bracket::default(), + path: syn::Ident::new("rpc", proc_macro2::Span::call_site()).into(), + tokens: quote! { (#attr) }, + }; -/// Determine the name of the parameter based on the pattern. -fn rpc_param_name(pat: &syn::Pat, _attrs: &[syn::Attribute]) -> syn::parse::Result { - // TODO: look in attributes if the user specified a param name - match pat { - // TODO: check other fields of the `PatIdent` - syn::Pat::Ident(ident) => Ok(ident.ident.to_string()), - _ => unimplemented!(), + match rpc_impl(rebuilt_rpc_attribute, item) { + Ok(tokens) => tokens, + Err(err) => err.to_compile_error(), } + .into() } -/// Search for `jsonrpsee` in `Cargo.toml`. -fn find_jsonrpsee_crate() -> Result { - match crate_name("jsonrpsee") { - Ok(FoundCrate::Name(name)) => { - let ident = syn::Ident::new(&name, Span::call_site()); - Ok(quote!(#ident::types)) - } - Ok(FoundCrate::Itself) => panic!("Deriving RPC methods in any of the `jsonrpsee crates` is not supported"), - Err(_) => match (crate_name("jsonrpsee-http-client"), crate_name("jsonrpsee-ws-client")) { - (Ok(FoundCrate::Name(name)), _) | (_, Ok(FoundCrate::Name(name))) => { - let ident = syn::Ident::new(&name, Span::call_site()); - Ok(quote!(#ident)) - } - (Ok(FoundCrate::Itself), _) | (_, Ok(FoundCrate::Itself)) => { - panic!("Deriving RPC methods in any of the `jsonrpsee crates` is not supported") - } - (_, Err(e)) => Err(syn::Error::new(Span::call_site(), &e)), - }, - } +/// Convenience form of `rpc` that may use `?` for error handling to avoid boilerplate. +fn rpc_impl(attr: syn::Attribute, item: TokenStream) -> Result { + let trait_data: syn::ItemTrait = syn::parse(item)?; + let rpc = RpcDescription::from_item(attr, trait_data)?; + rpc.render() } diff --git a/proc-macros/src/new/attributes.rs b/proc-macros/src/new/attributes.rs new file mode 100644 index 0000000000..3bcc41417a --- /dev/null +++ b/proc-macros/src/new/attributes.rs @@ -0,0 +1,51 @@ +use bae::FromAttributes; + +/// Input for the `#[rpc(...)]` attribute macro. +#[derive(Debug, Clone, FromAttributes)] +pub(crate) struct Rpc { + /// Switch denoting that server trait must be generated. + /// Assuming that trait to which attribute is applied is named `Foo`, the generated + /// server trait will have `FooServer` name. + pub server: Option<()>, + /// Switch denoting that client extension trait must be generated. + /// Assuming that trait to which attribute is applied is named `Foo`, the generated + /// client trait will have `FooClient` name. + pub client: Option<()>, + /// Optional prefix for RPC namespace. + pub namespace: Option, +} + +impl Rpc { + /// Returns `true` if at least one of `server` or `client` attributes is present. + pub(crate) fn is_correct(&self) -> bool { + self.server.is_some() || self.client.is_some() + } + + /// Returns `true` if server implementation was requested. + pub(crate) fn needs_server(&self) -> bool { + self.server.is_some() + } + + /// Returns `true` if client implementation was requested. + pub(crate) fn needs_client(&self) -> bool { + self.client.is_some() + } +} + +/// Input for the `#[method(...)]` attribute. +#[derive(Debug, Clone, FromAttributes)] +pub(crate) struct Method { + /// Method name + pub name: syn::LitStr, +} + +/// Input for the `#[subscription(...)]` attribute. +#[derive(Debug, Clone, FromAttributes)] +pub(crate) struct Subscription { + /// Subscription name + pub name: syn::LitStr, + /// Name of the method to unsubscribe. + pub unsub: syn::LitStr, + /// Type yielded by the subscription. + pub item: syn::Type, +} diff --git a/proc-macros/src/new/mod.rs b/proc-macros/src/new/mod.rs new file mode 100644 index 0000000000..7a6968c136 --- /dev/null +++ b/proc-macros/src/new/mod.rs @@ -0,0 +1,220 @@ +//! Declaration of the JSON RPC generator procedural macros. + +use self::respan::Respan; +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::Attribute; + +mod attributes; +mod render_client; +mod render_server; +mod respan; + +#[derive(Debug, Clone)] +pub struct RpcMethod { + pub name: syn::LitStr, + pub params: Vec<(syn::PatIdent, syn::Type)>, + pub returns: Option, + pub signature: syn::TraitItemMethod, +} + +impl RpcMethod { + pub fn from_item(mut method: syn::TraitItemMethod) -> Result { + let attributes = attributes::Method::from_attributes(&method.attrs).respan(&method.attrs.first())?; + let sig = method.sig.clone(); + let name = attributes.name; + let params: Vec<_> = sig + .inputs + .into_iter() + .filter_map(|arg| match arg { + syn::FnArg::Receiver(_) => None, + syn::FnArg::Typed(arg) => match *arg.pat { + syn::Pat::Ident(name) => Some((name, *arg.ty)), + _ => panic!("Identifier in signature must be an ident"), + }, + }) + .collect(); + + let returns = match sig.output { + syn::ReturnType::Default => None, + syn::ReturnType::Type(_, output) => Some(*output), + }; + + // We've analyzed attributes and don't need them anymore. + method.attrs.clear(); + + Ok(Self { name, params, returns, signature: method }) + } +} + +#[derive(Debug, Clone)] +pub struct RpcSubscription { + pub name: syn::LitStr, + pub unsub_method: syn::LitStr, + pub params: Vec<(syn::PatIdent, syn::Type)>, + pub item: syn::Type, + pub signature: syn::TraitItemMethod, +} + +impl RpcSubscription { + pub fn from_item(mut sub: syn::TraitItemMethod) -> Result { + let attributes = attributes::Subscription::from_attributes(&sub.attrs).respan(&sub.attrs.first())?; + let sig = sub.sig.clone(); + let name = attributes.name; + let unsub_method = attributes.unsub; + let item = attributes.item; + let params: Vec<_> = sig + .inputs + .into_iter() + .filter_map(|arg| match arg { + syn::FnArg::Receiver(_) => None, + syn::FnArg::Typed(arg) => match *arg.pat { + syn::Pat::Ident(name) => Some((name, *arg.ty)), + _ => panic!("Identifier in signature must be an ident"), + }, + }) + .collect(); + + // We've analyzed attributes and don't need them anymore. + sub.attrs.clear(); + + Ok(Self { name, unsub_method, params, item, signature: sub }) + } +} + +#[derive(Debug)] +pub struct RpcDescription { + /// Path to the `jsonrpsee` client types part. + jsonrpsee_client_path: Option, + /// Path to the `jsonrpsee` server types part. + jsonrpsee_server_path: Option, + /// Data about RPC declaration + attrs: attributes::Rpc, + /// Trait definition in which all the attributes were stripped. + trait_def: syn::ItemTrait, + /// List of RPC methods defined in the trait. + methods: Vec, + /// List of RPC subscritpions defined in the trait. + subscriptions: Vec, +} + +impl RpcDescription { + pub fn from_item(attr: syn::Attribute, mut item: syn::ItemTrait) -> Result { + let attrs = attributes::Rpc::from_attributes(&[attr.clone()]).respan(&attr)?; + if !attrs.is_correct() { + return Err(syn::Error::new_spanned(&item.ident, "Either 'server' or 'client' attribute must be applied")); + } + + let jsonrpsee_client_path = crate::helpers::find_jsonrpsee_client_crate().ok(); + let jsonrpsee_server_path = crate::helpers::find_jsonrpsee_server_crate().ok(); + + if attrs.needs_client() && jsonrpsee_client_path.is_none() { + return Err(syn::Error::new_spanned(&item.ident, "Unable to locate 'jsonrpsee' client dependency")); + } + if attrs.needs_server() && jsonrpsee_server_path.is_none() { + return Err(syn::Error::new_spanned(&item.ident, "Unable to locate 'jsonrpsee' server dependency")); + } + + item.attrs.clear(); // Remove RPC attributes. + + let mut methods = Vec::new(); + let mut subscriptions = Vec::new(); + + // Go through all the methods in the trait and collect methods and + // subscriptions. + for entry in item.items.iter() { + if let syn::TraitItem::Method(method) = entry { + if method.sig.receiver().is_none() { + return Err(syn::Error::new_spanned(&method.sig, "First argument of the trait must be '&self'")); + } + + let mut is_method = false; + let mut is_sub = false; + if has_attr(&method.attrs, "method") { + is_method = true; + + let method_data = RpcMethod::from_item(method.clone())?; + methods.push(method_data); + } + if has_attr(&method.attrs, "subscription") { + is_sub = true; + if is_method { + return Err(syn::Error::new_spanned( + &method, + "Element cannot be both subscription and method at the same time", + )); + } + if method.sig.asyncness.is_some() { + return Err(syn::Error::new_spanned(&method, "Subscription methods must not be `async`")); + } + if !matches!(method.sig.output, syn::ReturnType::Default) { + return Err(syn::Error::new_spanned(&method, "Subscription methods must not return anything")); + } + + let sub_data = RpcSubscription::from_item(method.clone())?; + subscriptions.push(sub_data); + } + + if !is_method && !is_sub { + return Err(syn::Error::new_spanned( + &method, + "Methods must have either 'method' or 'subscription' attribute", + )); + } + } else { + return Err(syn::Error::new_spanned(&entry, "Only methods allowed in RPC traits")); + } + } + + if methods.is_empty() && subscriptions.is_empty() { + return Err(syn::Error::new_spanned(&item, "RPC cannot be empty")); + } + + Ok(Self { jsonrpsee_client_path, jsonrpsee_server_path, attrs, trait_def: item, methods, subscriptions }) + } + + pub fn render(self) -> Result { + let server_impl = if self.attrs.needs_server() { self.render_server()? } else { TokenStream2::new() }; + let client_impl = if self.attrs.needs_client() { self.render_client()? } else { TokenStream2::new() }; + + Ok(quote! { + #server_impl + #client_impl + }) + } + + /// Formats the identifier as a path relative to the resolved + /// `jsonrpsee` client path. + fn jrps_client_item(&self, item: impl quote::ToTokens) -> TokenStream2 { + let jsonrpsee = self.jsonrpsee_client_path.as_ref().unwrap(); + quote! { #jsonrpsee::#item } + } + + /// Formats the identifier as a path relative to the resolved + /// `jsonrpsee` server path. + fn jrps_server_item(&self, item: impl quote::ToTokens) -> TokenStream2 { + let jsonrpsee = self.jsonrpsee_server_path.as_ref().unwrap(); + quote! { #jsonrpsee::#item } + } + + /// Based on the namespace, renders the full name of the RPC method/subscription. + /// Examples: + /// For namespace `foo` and method `makeSpam`, result will be `foo_makeSpam`. + /// For no namespace and method `makeSpam` it will be just `makeSpam. + fn rpc_identifier(&self, method: &syn::LitStr) -> String { + if let Some(ns) = &self.attrs.namespace { + format!("{}_{}", ns.value(), method.value()) + } else { + method.value() + } + } +} + +fn has_attr(attrs: &[Attribute], ident: &str) -> bool { + for attr in attrs.iter().filter_map(|a| a.path.get_ident()) { + if attr == ident { + return true; + } + } + false +} diff --git a/proc-macros/src/new/render_client.rs b/proc-macros/src/new/render_client.rs new file mode 100644 index 0000000000..286b3245fa --- /dev/null +++ b/proc-macros/src/new/render_client.rs @@ -0,0 +1,134 @@ +use super::{RpcDescription, RpcMethod, RpcSubscription}; +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; + +impl RpcDescription { + pub(super) fn render_client(&self) -> Result { + let jsonrpsee = self.jsonrpsee_client_path.as_ref().unwrap(); + + let trait_name = quote::format_ident!("{}Client", &self.trait_def.ident); + + let super_trait = if self.subscriptions.is_empty() { + quote! { #jsonrpsee::traits::Client } + } else { + quote! { #jsonrpsee::traits::SubscriptionClient } + }; + + let method_impls = + self.methods.iter().map(|method| self.render_method(method)).collect::, _>>()?; + let sub_impls = self.subscriptions.iter().map(|sub| self.render_sub(sub)).collect::, _>>()?; + + let async_trait = self.jrps_client_item(quote! { __reexports::async_trait }); + + // Doc-comment to be associated with the client. + let doc_comment = format!("Client implementation for the `{}` RPC API.", &self.trait_def.ident); + + let trait_impl = quote! { + #[#async_trait] + #[doc = #doc_comment] + pub trait #trait_name: #super_trait { + #(#method_impls)* + #(#sub_impls)* + } + + impl #trait_name for T where T: #super_trait {} + }; + + Ok(trait_impl) + } + + fn render_method(&self, method: &RpcMethod) -> Result { + // `jsonrpsee::Error` + let jrps_error = self.jrps_client_item(quote::format_ident!("Error")); + // Rust method to invoke (e.g. `self.(...)`). + let rust_method_name = &method.signature.sig.ident; + // List of inputs to put into `JsonRpcParams` (e.g. `self.foo(<12, "baz">)`). + // Includes `&self` receiver. + let rust_method_params = &method.signature.sig.inputs; + // Name of the RPC method (e.g. `foo_makeSpam`). + let rpc_method_name = self.rpc_identifier(&method.name); + + // Called method is either `request` or `notification`. + // `returns` represent the return type of the *rust method* (`Result< <..>, jsonrpsee::Error`). + let (called_method, returns) = if let Some(returns) = &method.returns { + let called_method = quote::format_ident!("request"); + let returns = quote! { Result<#returns, #jrps_error> }; + + (called_method, returns) + } else { + let called_method = quote::format_ident!("notification"); + let returns = quote! { Result<(), #jrps_error> }; + + (called_method, returns) + }; + + // Encoded parameters for the request. + let parameters = if !method.params.is_empty() { + let serde_json = self.jrps_client_item(quote! { __reexports::serde_json }); + let params = method.params.iter().map(|(param, _param_type)| { + quote! { #serde_json::to_value(&#param)? } + }); + + quote! { + vec![ #(#params),* ].into() + } + } else { + self.jrps_client_item(quote! { v2::params::JsonRpcParams::NoParams }) + }; + + // Doc-comment to be associated with the method. + let doc_comment = format!("Invokes the RPC method `{}`.", rpc_method_name); + + let method = quote! { + #[doc = #doc_comment] + async fn #rust_method_name(#rust_method_params) -> #returns { + self.#called_method(#rpc_method_name, #parameters).await + } + }; + Ok(method) + } + + fn render_sub(&self, sub: &RpcSubscription) -> Result { + // `jsonrpsee::Error` + let jrps_error = self.jrps_client_item(quote::format_ident!("Error")); + // Rust method to invoke (e.g. `self.(...)`). + let rust_method_name = &sub.signature.sig.ident; + // List of inputs to put into `JsonRpcParams` (e.g. `self.foo(<12, "baz">)`). + let rust_method_params = &sub.signature.sig.inputs; + // Name of the RPC subscription (e.g. `foo_sub`). + let rpc_sub_name = self.rpc_identifier(&sub.name); + // Name of the RPC method to unsubscribe (e.g. `foo_unsub`). + let rpc_unsub_name = self.rpc_identifier(&sub.unsub_method); + + // `returns` represent the return type of the *rust method*, which is wrapped + // into the `Subscription` object. + let sub_type = self.jrps_client_item(quote::format_ident!("Subscription")); + let item = &sub.item; + let returns = quote! { Result<#sub_type<#item>, #jrps_error> }; + + // Encoded parameters for the request. + let parameters = if !sub.params.is_empty() { + let serde_json = self.jrps_client_item(quote! { __reexports::serde_json }); + let params = sub.params.iter().map(|(param, _param_type)| { + quote! { #serde_json::to_value(&#param)? } + }); + + quote! { + vec![ #(#params),* ].into() + } + } else { + self.jrps_client_item(quote! { v2::params::JsonRpcParams::NoParams }) + }; + + // Doc-comment to be associated with the method. + let doc_comment = format!("Subscribes to the RPC method `{}`.", rpc_sub_name); + + let method = quote! { + #[doc = #doc_comment] + async fn #rust_method_name(#rust_method_params) -> #returns { + self.subscribe(#rpc_sub_name, #parameters, #rpc_unsub_name).await + } + }; + Ok(method) + } +} diff --git a/proc-macros/src/new/render_server.rs b/proc-macros/src/new/render_server.rs new file mode 100644 index 0000000000..0851035000 --- /dev/null +++ b/proc-macros/src/new/render_server.rs @@ -0,0 +1,263 @@ +use super::RpcDescription; +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; + +impl RpcDescription { + pub(super) fn render_server(&self) -> Result { + let trait_name = quote::format_ident!("{}Server", &self.trait_def.ident); + + let method_impls = self.render_methods()?; + let into_rpc_impl = self.render_into_rpc()?; + + let async_trait = self.jrps_server_item(quote! { __reexports::async_trait }); + + // Doc-comment to be associated with the server. + let doc_comment = format!("Server trait implementation for the `{}` RPC API.", &self.trait_def.ident); + + let trait_impl = quote! { + #[#async_trait] + #[doc = #doc_comment] + pub trait #trait_name: Sized + Send + Sync + 'static { + #method_impls + #into_rpc_impl + } + }; + + Ok(trait_impl) + } + + fn render_methods(&self) -> Result { + let methods = self.methods.iter().map(|method| &method.signature); + + let subscription_sink_ty = self.jrps_server_item(quote! { SubscriptionSink }); + let subscriptions = self.subscriptions.iter().cloned().map(|mut sub| { + // Add `SubscriptionSink` as the second input parameter to the signature. + let subscription_sink: syn::FnArg = syn::parse_quote!(subscription_sink: #subscription_sink_ty); + sub.signature.sig.inputs.insert(1, subscription_sink); + sub.signature + }); + + Ok(quote! { + #(#methods)* + #(#subscriptions)* + }) + } + + fn render_into_rpc(&self) -> Result { + let jrps_error = self.jrps_server_item(quote! { error::Error }); + let rpc_module = self.jrps_server_item(quote! { RpcModule }); + + let methods = self.methods.iter().map(|method| { + // Rust method to invoke (e.g. `self.(...)`). + let rust_method_name = &method.signature.sig.ident; + // Name of the RPC method (e.g. `foo_makeSpam`). + let rpc_method_name = self.rpc_identifier(&method.name); + // `parsing` is the code associated with parsing structure from the + // provided `RpcParams` object. + // `params_seq` is the comma-delimited sequence of parametsrs. + let is_method = true; + let (parsing, params_seq) = self.render_params_decoding(&method.params, is_method); + + if method.signature.sig.asyncness.is_some() { + quote! { + rpc.register_async_method(#rpc_method_name, |params, context| { + let owned_params = params.owned(); + let fut = async move { + let params = owned_params.borrowed(); + #parsing + Ok(context.as_ref().#rust_method_name(#params_seq).await) + }; + Box::pin(fut) + })?; + } + } else { + quote! { + rpc.register_method(#rpc_method_name, |params, context| { + #parsing + Ok(context.#rust_method_name(#params_seq)) + })?; + } + } + }); + + let subscriptions = self.subscriptions.iter().map(|sub| { + // Rust method to invoke (e.g. `self.(...)`). + let rust_method_name = &sub.signature.sig.ident; + // Name of the RPC method to subscribe (e.g. `foo_sub`). + let rpc_sub_name = self.rpc_identifier(&sub.name); + // Name of the RPC method to unsubscribe (e.g. `foo_sub`). + let rpc_unsub_name = self.rpc_identifier(&sub.unsub_method); + // `parsing` is the code associated with parsing structure from the + // provided `RpcParams` object. + // `params_seq` is the comma-delimited sequence of parametsrs. + let is_method = false; + let (parsing, params_seq) = self.render_params_decoding(&sub.params, is_method); + + quote! { + rpc.register_subscription(#rpc_sub_name, #rpc_unsub_name, |params, sink, context| { + #parsing + Ok(context.as_ref().#rust_method_name(sink, #params_seq)) + })?; + } + }); + + let doc_comment = "Collects all the methods and subscriptions defined in the trait \ + and adds them into a single `RpcModule`."; + + Ok(quote! { + #[doc = #doc_comment] + fn into_rpc(self) -> Result<#rpc_module, #jrps_error> { + let mut rpc = #rpc_module::new(self); + + #(#methods)* + #(#subscriptions)* + + Ok(rpc) + } + }) + } + + fn render_params_decoding( + &self, + params: &[(syn::PatIdent, syn::Type)], + is_method: bool, + ) -> (TokenStream2, TokenStream2) { + if params.is_empty() { + return (TokenStream2::default(), TokenStream2::default()); + } + + // Implementations for `.map_err(...)?` and `.ok_or(...)?` with respect to the expected + // error return type. + let (err, map_err_impl, ok_or_impl) = if is_method { + // For methods, we return `CallError`. + let jrps_call_error = self.jrps_server_item(quote! { error::CallError }); + let err = quote! { #jrps_call_error::InvalidParams }; + let map_err = quote! { .map_err(|_| #jrps_call_error::InvalidParams)? }; + let ok_or = quote! { .ok_or(#jrps_call_error::InvalidParams)? }; + (err, map_err, ok_or) + } else { + // For subscriptions, we return `Error`. + // Note that while `Error` can be constructed from `CallError`, we should not do it, + // because it will be an obuse of the error type semantics. + // Instead, we use suitable top-level error variants. + let jrps_error = self.jrps_server_item(quote! { error::Error }); + let err = quote! { #jrps_error::Request("Required paramater missing".into()) }; + let map_err = quote! { .map_err(|err| #jrps_error::ParseError(err))? }; + let ok_or = quote! { .ok_or(#jrps_error::Request("Required paramater missing".into()))? }; + (err, map_err, ok_or) + }; + + let serde_json = self.jrps_server_item(quote! { __reexports::serde_json }); + + // Parameters encoded as a tuple (to be parsed from array). + let (params_fields_seq, params_types_seq): (Vec<_>, Vec<_>) = params.iter().cloned().unzip(); + let params_types = quote! { (#(#params_types_seq),*) }; + let params_fields = quote! { (#(#params_fields_seq),*) }; + + // Code to decode sequence of parameters from a JSON array. + let decode_array = { + let decode_fields = params.iter().enumerate().map(|(id, (name, ty))| { + if is_option(ty) { + quote! { + let #name = arr + .get(#id) + .cloned() + .map(#serde_json::from_value) + .transpose() + #map_err_impl; + } + } else { + quote! { + let #name = arr + .get(#id) + .cloned() + .map(#serde_json::from_value) + #ok_or_impl + #map_err_impl; + } + } + }); + + quote! { + #(#decode_fields);* + #params_fields + } + }; + + // Code to decode sequence of parameters from a JSON object (aka map). + let decode_map = { + let decode_fields = params.iter().map(|(name, ty)| { + let name_str = name.ident.to_string(); + if is_option(ty) { + quote! { + let #name = obj + .get(#name_str) + .cloned() + .map(#serde_json::from_value) + .transpose() + #map_err_impl; + } + } else { + quote! { + let #name = obj + .get(#name_str) + .cloned() + .map(#serde_json::from_value) + #ok_or_impl + #map_err_impl; + } + } + }); + + quote! { + #(#decode_fields);* + #params_fields + } + }; + + // Code to decode single parameter from a JSON primitive. + let decode_single = if params.len() == 1 { + quote! { + #serde_json::from_value(json) + #map_err_impl + } + } else { + quote! { return Err(#err);} + }; + + // Parsing of `serde_json::Value`. + let parsing = quote! { + let json: #serde_json::Value = params.parse()?; + let #params_fields: #params_types = match json { + #serde_json::Value::Null => return Err(#err), + #serde_json::Value::Array(arr) => { + #decode_array + } + #serde_json::Value::Object(obj) => { + #decode_map + } + _ => { + #decode_single + } + }; + }; + + let seq = quote! { + #(#params_fields_seq),* + }; + + (parsing, seq) + } +} + +/// Checks whether provided type is an `Option<...>`. +fn is_option(ty: &syn::Type) -> bool { + if let syn::Type::Path(path) = ty { + // TODO: Probably not the best way to check whether type is an `Option`. + if path.path.segments.iter().any(|seg| seg.ident == "Option") { + return true; + } + } + + false +} diff --git a/proc-macros/src/new/respan.rs b/proc-macros/src/new/respan.rs new file mode 100644 index 0000000000..ab8a1678d6 --- /dev/null +++ b/proc-macros/src/new/respan.rs @@ -0,0 +1,16 @@ +//! Module with a trait extension capable of re-spanning `syn` errors. + +use quote::ToTokens; + +/// Trait capable of changing `Span` set in the `syn::Error` so in case +/// of dependency setting it incorrectly, it is possible to easily create +/// a new error with the correct span. +pub trait Respan { + fn respan(self, spanned: S) -> Result; +} + +impl Respan for Result { + fn respan(self, spanned: S) -> Result { + self.map_err(|e| syn::Error::new_spanned(spanned, e)) + } +} diff --git a/proc-macros/tests/rpc_example.rs b/proc-macros/tests/rpc_example.rs new file mode 100644 index 0000000000..4e81c49551 --- /dev/null +++ b/proc-macros/tests/rpc_example.rs @@ -0,0 +1,15 @@ +//! Example of using proc macro to generate working client and server. + +use jsonrpsee_proc_macros::rpc; + +#[rpc(client, server, namespace = "foo")] +pub trait Rpc { + #[method(name = "foo")] + async fn async_method(&self, param_a: u8, param_b: String) -> u16; + + #[method(name = "bar")] + fn sync_method(&self) -> u16; + + #[subscription(name = "sub", unsub = "unsub", item = String)] + fn sub(&self); +} diff --git a/proc-macros/tests/ui.rs b/proc-macros/tests/ui.rs new file mode 100644 index 0000000000..3e45a93e38 --- /dev/null +++ b/proc-macros/tests/ui.rs @@ -0,0 +1,18 @@ +//! UI test set uses [`trybuild`](https://docs.rs/trybuild/1.0.42/trybuild/) to +//! check whether expected valid examples of code compile correctly, and for incorrect ones +//! errors are helpful and valid (e.g. have correct spans). +//! +//! Use with `TRYBUILD=overwrite` after updating codebase (see `trybuild` docs for more details on that) +//! to automatically regenerate `stderr` files, but don't forget to check that new files make sense. + +#[test] +fn ui_pass() { + let t = trybuild::TestCases::new(); + t.pass("tests/ui/correct/*.rs"); +} + +#[test] +fn ui_fail() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/incorrect/**/*.rs"); +} diff --git a/proc-macros/tests/ui/correct/basic.rs b/proc-macros/tests/ui/correct/basic.rs new file mode 100644 index 0000000000..e7023cfad2 --- /dev/null +++ b/proc-macros/tests/ui/correct/basic.rs @@ -0,0 +1,80 @@ +//! Example of using proc macro to generate working client and server. + +use jsonrpsee::{ + proc_macros::rpc, + types::async_trait, + ws_client::*, + ws_server::{SubscriptionSink, WsServerBuilder}, +}; +use std::{net::SocketAddr, sync::mpsc::channel}; + +#[rpc(client, server, namespace = "foo")] +pub trait Rpc { + #[method(name = "foo")] + async fn async_method(&self, param_a: u8, param_b: String) -> u16; + + #[method(name = "bar")] + fn sync_method(&self) -> u16; + + #[subscription(name = "sub", unsub = "unsub", item = String)] + fn sub(&self); + + #[subscription(name = "echo", unsub = "no_more_echo", item = u32)] + fn sub_with_params(&self, val: u32); +} + +pub struct RpcServerImpl; + +#[async_trait] +impl RpcServer for RpcServerImpl { + async fn async_method(&self, _param_a: u8, _param_b: String) -> u16 { + 42u16 + } + + fn sync_method(&self) -> u16 { + 10u16 + } + + fn sub(&self, mut sink: SubscriptionSink) { + sink.send(&"Response_A").unwrap(); + sink.send(&"Response_B").unwrap(); + } + + fn sub_with_params(&self, mut sink: SubscriptionSink, val: u32) { + sink.send(&val).unwrap(); + sink.send(&val).unwrap(); + } +} + +pub async fn websocket_server() -> SocketAddr { + let (server_started_tx, server_started_rx) = channel(); + + std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); + server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); + + rt.block_on(async move { + server_started_tx.send(server.local_addr().unwrap()).unwrap(); + + server.start().await + }); + }); + + server_started_rx.recv().unwrap() +} + +#[tokio::main] +async fn main() { + let server_addr = websocket_server().await; + let server_url = format!("ws://{}", server_addr); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + assert_eq!(client.async_method(10, "a".into()).await.unwrap(), 42); + assert_eq!(client.sync_method().await.unwrap(), 10); + let mut sub = client.sub().await.unwrap(); + let first_recv = sub.next().await.unwrap(); + assert_eq!(first_recv, Some("Response_A".to_string())); + let second_recv = sub.next().await.unwrap(); + assert_eq!(second_recv, Some("Response_B".to_string())); +} diff --git a/proc-macros/tests/ui/correct/only_client.rs b/proc-macros/tests/ui/correct/only_client.rs new file mode 100644 index 0000000000..c04d87e96c --- /dev/null +++ b/proc-macros/tests/ui/correct/only_client.rs @@ -0,0 +1,17 @@ +//! Example of using proc macro to generate working client and server. + +use jsonrpsee::proc_macros::rpc; + +#[rpc(client)] +pub trait Rpc { + #[method(name = "foo")] + async fn async_method(&self, param_a: u8, param_b: String) -> u16; + + #[method(name = "bar")] + fn sync_method(&self) -> u16; + + #[subscription(name = "sub", unsub = "unsub", item = String)] + fn sub(&self); +} + +fn main() {} diff --git a/proc-macros/tests/ui/correct/only_server.rs b/proc-macros/tests/ui/correct/only_server.rs new file mode 100644 index 0000000000..0751c9d7f9 --- /dev/null +++ b/proc-macros/tests/ui/correct/only_server.rs @@ -0,0 +1,59 @@ +use jsonrpsee::{ + proc_macros::rpc, + types::async_trait, + ws_server::{SubscriptionSink, WsServerBuilder}, +}; +use std::{net::SocketAddr, sync::mpsc::channel}; + +#[rpc(server)] +pub trait Rpc { + #[method(name = "foo")] + async fn async_method(&self, param_a: u8, param_b: String) -> u16; + + #[method(name = "bar")] + fn sync_method(&self) -> u16; + + #[subscription(name = "sub", unsub = "unsub", item = String)] + fn sub(&self); +} + +pub struct RpcServerImpl; + +#[async_trait] +impl RpcServer for RpcServerImpl { + async fn async_method(&self, _param_a: u8, _param_b: String) -> u16 { + 42u16 + } + + fn sync_method(&self) -> u16 { + 10u16 + } + + fn sub(&self, mut sink: SubscriptionSink) { + sink.send(&"Response_A").unwrap(); + sink.send(&"Response_B").unwrap(); + } +} + +pub async fn websocket_server() -> SocketAddr { + let (server_started_tx, server_started_rx) = channel(); + + std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); + server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); + + rt.block_on(async move { + server_started_tx.send(server.local_addr().unwrap()).unwrap(); + + server.start().await + }); + }); + + server_started_rx.recv().unwrap() +} + +#[tokio::main] +async fn main() { + let _server_addr = websocket_server().await; +} diff --git a/proc-macros/tests/ui/incorrect/method/method_no_name.rs b/proc-macros/tests/ui/incorrect/method/method_no_name.rs new file mode 100644 index 0000000000..a3375d2f0d --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_no_name.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Missing mandatory `name` field. +#[rpc(client, server)] +pub trait NoMethodName { + #[method()] + async fn async_method(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/method/method_no_name.stderr b/proc-macros/tests/ui/incorrect/method/method_no_name.stderr new file mode 100644 index 0000000000..954fb60c77 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_no_name.stderr @@ -0,0 +1,5 @@ +error: unexpected end of input, `#[method]` is missing `name` argument + --> $DIR/method_no_name.rs:6:2 + | +6 | #[method()] + | ^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/method/method_unexpected_field.rs b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.rs new file mode 100644 index 0000000000..c9b7c2e4d5 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Unsupported attribute field. +#[rpc(client, server)] +pub trait UnexpectedField { + #[method(name = "foo", magic = false)] + async fn async_method(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr new file mode 100644 index 0000000000..ab51a49ab4 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/method/method_unexpected_field.stderr @@ -0,0 +1,5 @@ +error: `#[method]` got unknown `magic` argument. Supported arguments are `name` + --> $DIR/method_unexpected_field.rs:6:2 + | +6 | #[method(name = "foo", magic = false)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.rs b/proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.rs new file mode 100644 index 0000000000..e84e7819e5 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.rs @@ -0,0 +1,20 @@ +use jsonrpsee::proc_macros::rpc; + +// Associated items are forbidden. +#[rpc(client, server)] +pub trait AssociatedConst { + const WOO: usize; + + #[method(name = "foo")] + async fn async_method(&self) -> u8; +} + +#[rpc(client, server)] +pub trait AssociatedType { + type Woo; + + #[method(name = "foo")] + async fn async_method(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.stderr b/proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.stderr new file mode 100644 index 0000000000..cd701971b3 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_assoc_items.stderr @@ -0,0 +1,11 @@ +error: Only methods allowed in RPC traits + --> $DIR/rpc_assoc_items.rs:6:2 + | +6 | const WOO: usize; + | ^^^^^^^^^^^^^^^^^ + +error: Only methods allowed in RPC traits + --> $DIR/rpc_assoc_items.rs:14:2 + | +14 | type Woo; + | ^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_empty.rs b/proc-macros/tests/ui/incorrect/rpc/rpc_empty.rs new file mode 100644 index 0000000000..5f58f9f846 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_empty.rs @@ -0,0 +1,7 @@ +use jsonrpsee::proc_macros::rpc; + +// Empty RPC is forbidden. +#[rpc(client, server)] +pub trait Empty {} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_empty.stderr b/proc-macros/tests/ui/incorrect/rpc/rpc_empty.stderr new file mode 100644 index 0000000000..27b8f4592e --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_empty.stderr @@ -0,0 +1,5 @@ +error: RPC cannot be empty + --> $DIR/rpc_empty.rs:5:1 + | +5 | pub trait Empty {} + | ^^^^^^^^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.rs b/proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.rs new file mode 100644 index 0000000000..4e8e98dbb0 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Either client or server field must be provided. +#[rpc()] +pub trait NoImpls { + #[method(name = "foo")] + async fn async_method(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.stderr b/proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.stderr new file mode 100644 index 0000000000..3e7ea6fe2a --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_no_impls.stderr @@ -0,0 +1,5 @@ +error: Either 'server' or 'client' attribute must be applied + --> $DIR/rpc_no_impls.rs:5:11 + | +5 | pub trait NoImpls { + | ^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.rs b/proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.rs new file mode 100644 index 0000000000..46d48d5161 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.rs @@ -0,0 +1,9 @@ +use jsonrpsee::proc_macros::rpc; + +// Method without type marker. +#[rpc(client, server)] +pub trait NotQualified { + async fn async_method(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.stderr b/proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.stderr new file mode 100644 index 0000000000..5f41617512 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_not_qualified.stderr @@ -0,0 +1,5 @@ +error: Methods must have either 'method' or 'subscription' attribute + --> $DIR/rpc_not_qualified.rs:6:2 + | +6 | async fn async_method(&self) -> u8; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_async.rs b/proc-macros/tests/ui/incorrect/sub/sub_async.rs new file mode 100644 index 0000000000..ec1ff98e45 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_async.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Subscription method must not be async. +#[rpc(client, server)] +pub trait AsyncSub { + #[subscription(name = "sub", unsub = "unsub", item = u8)] + async fn sub(&self); +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_async.stderr b/proc-macros/tests/ui/incorrect/sub/sub_async.stderr new file mode 100644 index 0000000000..c524c677ef --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_async.stderr @@ -0,0 +1,6 @@ +error: Subscription methods must not be `async` + --> $DIR/sub_async.rs:6:2 + | +6 | / #[subscription(name = "sub", unsub = "unsub", item = u8)] +7 | | async fn sub(&self); + | |________________________^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs b/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs new file mode 100644 index 0000000000..55124fe914 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Missing all the mandatory fields. +#[rpc(client, server)] +pub trait SubEmptyAttr { + #[subscription()] + fn sub(&self); +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.stderr b/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.stderr new file mode 100644 index 0000000000..035144fcbf --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_empty_attr.stderr @@ -0,0 +1,5 @@ +error: unexpected end of input, `#[subscription]` is missing `name` argument + --> $DIR/sub_empty_attr.rs:6:2 + | +6 | #[subscription()] + | ^^^^^^^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_item.rs b/proc-macros/tests/ui/incorrect/sub/sub_no_item.rs new file mode 100644 index 0000000000..27c5573240 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_item.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Missing mandatory `item` field. +#[rpc(client, server)] +pub trait NoSubItem { + #[subscription(name = "sub", unsub = "unsub")] + fn sub(&self); +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_item.stderr b/proc-macros/tests/ui/incorrect/sub/sub_no_item.stderr new file mode 100644 index 0000000000..93317246dd --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_item.stderr @@ -0,0 +1,5 @@ +error: unexpected end of input, `#[subscription]` is missing `item` argument + --> $DIR/sub_no_item.rs:6:2 + | +6 | #[subscription(name = "sub", unsub = "unsub")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_name.rs b/proc-macros/tests/ui/incorrect/sub/sub_no_name.rs new file mode 100644 index 0000000000..4f2786b148 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_name.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Missing mandatory `name` field. +#[rpc(client, server)] +pub trait NoSubName { + #[subscription(unsub = "unsub", item = String)] + async fn async_method(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_name.stderr b/proc-macros/tests/ui/incorrect/sub/sub_no_name.stderr new file mode 100644 index 0000000000..d4f7828790 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_name.stderr @@ -0,0 +1,6 @@ +error: Subscription methods must not be `async` + --> $DIR/sub_no_name.rs:6:2 + | +6 | / #[subscription(unsub = "unsub", item = String)] +7 | | async fn async_method(&self) -> u8; + | |_______________________________________^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_unsub.rs b/proc-macros/tests/ui/incorrect/sub/sub_no_unsub.rs new file mode 100644 index 0000000000..a8e13ac144 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_unsub.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Missing mandatory `unsub` field. +#[rpc(client, server)] +pub trait NoSubUnsub { + #[subscription(name = "sub", item = String)] + fn sub(&self); +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_no_unsub.stderr b/proc-macros/tests/ui/incorrect/sub/sub_no_unsub.stderr new file mode 100644 index 0000000000..59b99c643d --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_no_unsub.stderr @@ -0,0 +1,5 @@ +error: unexpected end of input, `#[subscription]` is missing `unsub` argument + --> $DIR/sub_no_unsub.rs:6:2 + | +6 | #[subscription(name = "sub", item = String)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_return_type.rs b/proc-macros/tests/ui/incorrect/sub/sub_return_type.rs new file mode 100644 index 0000000000..b208deaaeb --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_return_type.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Subscription method must not have return type. +#[rpc(client, server)] +pub trait SubWithReturnType { + #[subscription(name = "sub", unsub = "unsub", item = u8)] + fn sub(&self) -> u8; +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_return_type.stderr b/proc-macros/tests/ui/incorrect/sub/sub_return_type.stderr new file mode 100644 index 0000000000..aceb16d2fc --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_return_type.stderr @@ -0,0 +1,6 @@ +error: Subscription methods must not return anything + --> $DIR/sub_return_type.rs:6:2 + | +6 | / #[subscription(name = "sub", unsub = "unsub", item = u8)] +7 | | fn sub(&self) -> u8; + | |________________________^ diff --git a/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs new file mode 100644 index 0000000000..a12f1da41f --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.rs @@ -0,0 +1,10 @@ +use jsonrpsee::proc_macros::rpc; + +// Unsupported attribute field. +#[rpc(client, server)] +pub trait UnsupportedField { + #[subscription(name = "sub", unsub = "unsub", item = u8, magic = true)] + fn sub(&self); +} + +fn main() {} diff --git a/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr new file mode 100644 index 0000000000..fed19dfca9 --- /dev/null +++ b/proc-macros/tests/ui/incorrect/sub/sub_unsupported_field.stderr @@ -0,0 +1,5 @@ +error: `#[subscription]` got unknown `magic` argument. Supported arguments are `item`, `name`, `unsub` + --> $DIR/sub_unsupported_field.rs:6:2 + | +6 | #[subscription(name = "sub", unsub = "unsub", item = u8, magic = true)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/tests/new_proc_macros.rs b/tests/tests/new_proc_macros.rs new file mode 100644 index 0000000000..fdf729a72d --- /dev/null +++ b/tests/tests/new_proc_macros.rs @@ -0,0 +1,93 @@ +//! Example of using proc macro to generate working client and server. + +use std::net::SocketAddr; + +use futures_channel::oneshot; +use jsonrpsee::{ws_client::*, ws_server::WsServerBuilder}; + +mod rpc_impl { + use jsonrpsee::{proc_macros::rpc, types::async_trait, ws_server::SubscriptionSink}; + + #[rpc(client, server, namespace = "foo")] + pub trait Rpc { + #[method(name = "foo")] + async fn async_method(&self, param_a: u8, param_b: String) -> u16; + + #[method(name = "bar")] + fn sync_method(&self) -> u16; + + #[subscription(name = "sub", unsub = "unsub", item = String)] + fn sub(&self); + + #[subscription(name = "echo", unsub = "no_more_echo", item = u32)] + fn sub_with_params(&self, val: u32); + } + + pub struct RpcServerImpl; + + #[async_trait] + impl RpcServer for RpcServerImpl { + async fn async_method(&self, _param_a: u8, _param_b: String) -> u16 { + 42u16 + } + + fn sync_method(&self) -> u16 { + 10u16 + } + + fn sub(&self, mut sink: SubscriptionSink) { + sink.send(&"Response_A").unwrap(); + sink.send(&"Response_B").unwrap(); + } + + fn sub_with_params(&self, mut sink: SubscriptionSink, val: u32) { + sink.send(&val).unwrap(); + sink.send(&val).unwrap(); + } + } +} + +// Use generated implementations of server and client. +use rpc_impl::{RpcClient, RpcServer, RpcServerImpl}; + +pub async fn websocket_server() -> SocketAddr { + let (server_started_tx, server_started_rx) = oneshot::channel(); + + std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut server = rt.block_on(WsServerBuilder::default().build("127.0.0.1:0")).unwrap(); + server.register_module(RpcServerImpl.into_rpc().unwrap()).unwrap(); + + rt.block_on(async move { + server_started_tx.send(server.local_addr().unwrap()).unwrap(); + + server.start().await + }); + }); + + server_started_rx.await.unwrap() +} + +#[tokio::test] +async fn proc_macros_generic_ws_client_api() { + let server_addr = websocket_server().await; + let server_url = format!("ws://{}", server_addr); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + assert_eq!(client.async_method(10, "a".into()).await.unwrap(), 42); + assert_eq!(client.sync_method().await.unwrap(), 10); + + // Sub without params + let mut sub = client.sub().await.unwrap(); + let first_recv = sub.next().await.unwrap(); + assert_eq!(first_recv, Some("Response_A".to_string())); + let second_recv = sub.next().await.unwrap(); + assert_eq!(second_recv, Some("Response_B".to_string())); + + // Sub with params + let mut sub = client.sub_with_params(42).await.unwrap(); + let first_recv = sub.next().await.unwrap(); + assert_eq!(first_recv, Some(42)); + let second_recv = sub.next().await.unwrap(); + assert_eq!(second_recv, Some(42)); +} diff --git a/types/src/lib.rs b/types/src/lib.rs index 0009aee4b6..aae50dfd9d 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -20,6 +20,7 @@ mod client; /// Traits pub mod traits; +pub use async_trait::async_trait; pub use beef::Cow; pub use client::*; pub use error::Error; @@ -28,3 +29,12 @@ pub use serde_json::{ to_value as to_json_value, value::to_raw_value as to_json_raw_value, value::RawValue as JsonRawValue, Value as JsonValue, }; + +/// Re-exports for proc-macro library to not require any additional +/// dependencies to be explicitly added on the client side. +#[doc(hidden)] +pub mod __reexports { + pub use async_trait::async_trait; + pub use serde; + pub use serde_json; +} diff --git a/types/src/v2/params.rs b/types/src/v2/params.rs index 890a4caea9..d93a63f2a8 100644 --- a/types/src/v2/params.rs +++ b/types/src/v2/params.rs @@ -167,6 +167,13 @@ impl<'a> RpcParams<'a> { { self.parse::<[T; 1]>().map(|[res]| res) } + + /// Creates an owned version of parameters. + /// Required to simplify proc-macro implementation. + #[doc(hidden)] + pub fn owned(self) -> OwnedRpcParams { + self.into() + } } /// Owned version of [`RpcParams`].