Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass OwnedRpcParams to async methods #410

Merged
merged 8 commits into from
Jul 10, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions proc-macros/src/new/render_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ impl RpcDescription {
if method.signature.sig.asyncness.is_some() {
quote! {
rpc.register_async_method(#rpc_method_name, |params, context| {
let owned_params = params.owned();
let fut = async move {
let params = owned_params.borrowed();
#parsing
Ok(context.as_ref().#rust_method_name(#params_seq).await)
};
Expand Down
100 changes: 49 additions & 51 deletions types/src/v2/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,42 +59,54 @@ impl Serialize for TwoPointZero {
}

/// Parameters sent with the RPC request
#[derive(Clone, Copy, Debug)]
pub struct RpcParams<'a>(Option<&'a str>);
#[derive(Clone, Debug)]
pub struct RpcParams<'a> {
json: Option<Cow<'a, str>>,
offset: usize,
}

impl<'a> RpcParams<'a> {
/// Create params
pub fn new(raw: Option<&'a str>) -> Self {
Self(raw)
pub fn new(json: Option<&'a str>) -> Self {
Self {
json: json.map(Into::into),
offset: 0,
}
}

fn next_inner<T>(&mut self) -> Option<Result<T, CallError>>
fn next_inner<'temp, T>(&'temp mut self) -> Option<Result<T, CallError>>
where
T: Deserialize<'a>,
T: Deserialize<'temp>,
{
let mut json = self.0?.trim_start();

match json.as_bytes().get(0)? {
b']' => {
self.0 = None;

return None;
let json = self.json.as_ref()?.as_ref();

loop {
match json.as_bytes().get(self.offset)? {
b']' => {
return None;
}
b'[' | b',' => {
self.offset += 1;
break;
}
b' ' | b'\n' | b'\r' | b'\t' | 0x0C => {
self.offset += 1;
}
_ => {
return Some(Err(CallError::InvalidParams));
}
}
b'[' | b',' => json = &json[1..],
_ => return Some(Err(CallError::InvalidParams)),
}

let mut iter = serde_json::Deserializer::from_str(json).into_iter::<T>();
let mut iter = serde_json::Deserializer::from_str(&json[self.offset..]).into_iter::<T>();

match iter.next()? {
Ok(value) => {
self.0 = Some(&json[iter.byte_offset()..]);
self.offset += iter.byte_offset();

Some(Ok(value))
}
Err(_) => {
self.0 = None;

Some(Err(CallError::InvalidParams))
}
}
Expand All @@ -114,9 +126,9 @@ impl<'a> RpcParams<'a> {
/// assert_eq!(b, 10);
/// assert_eq!(c, "foo");
/// ```
pub fn next<T>(&mut self) -> Result<T, CallError>
pub fn next<'temp, T>(&'temp mut self) -> Result<T, CallError>
where
T: Deserialize<'a>,
T: Deserialize<'temp>,
{
match self.next_inner() {
Some(result) => result,
Expand All @@ -141,9 +153,9 @@ impl<'a> RpcParams<'a> {
///
/// assert_eq!(params, [Some(1), Some(2), None, None]);
/// ```
pub fn optional_next<T>(&mut self) -> Result<Option<T>, CallError>
pub fn optional_next<'temp, T>(&'temp mut self) -> Result<Option<T>, CallError>
where
T: Deserialize<'a>,
T: Deserialize<'temp>,
{
match self.next_inner::<Option<T>>() {
Some(result) => result,
Expand All @@ -152,52 +164,38 @@ impl<'a> RpcParams<'a> {
}

/// Attempt to parse all parameters as array or map into type `T`
pub fn parse<T>(self) -> Result<T, CallError>
pub fn parse<T>(&'a self) -> Result<T, CallError>
where
T: Deserialize<'a>,
{
let params = self.0.unwrap_or("null");
let params = self.json.as_ref().map(AsRef::as_ref).unwrap_or("null");
maciejhirsz marked this conversation as resolved.
Show resolved Hide resolved
serde_json::from_str(params).map_err(|_| CallError::InvalidParams)
}

/// Attempt to parse parameters as an array of a single value of type `T`, and returns that value.
pub fn one<T>(self) -> Result<T, CallError>
pub fn one<T>(&'a self) -> Result<T, CallError>
where
T: Deserialize<'a>,
{
self.parse::<[T; 1]>().map(|[res]| res)
}

/// Creates an owned version of parameters.
/// Required to simplify proc-macro implementation.
#[doc(hidden)]
pub fn owned(self) -> OwnedRpcParams {
self.into()
}
}

/// Owned version of [`RpcParams`].
#[derive(Clone, Debug)]
pub struct OwnedRpcParams(Option<String>);

impl OwnedRpcParams {
/// Converts `OwnedRpcParams` into borrowed [`RpcParams`].
pub fn borrowed(&self) -> RpcParams<'_> {
RpcParams(self.0.as_ref().map(|s| s.as_ref()))
}
}

impl<'a> From<RpcParams<'a>> for OwnedRpcParams {
fn from(borrowed: RpcParams<'a>) -> Self {
Self(borrowed.0.map(Into::into))
/// Convert `RpcParams<'a>` to `RpcParams<'static>` so that it can be moved across threads.
///
/// This will cause an allocation if the params internally are using a borrowed JSON slice.
pub fn into_owned(self) -> RpcParams<'static> {
RpcParams {
json: self.json.map(|s| Cow::owned(s.into_owned())),
offset: self.offset,
}
}
}

/// [Serializable JSON-RPC parameters](https://www.jsonrpc.org/specification#parameter_structures)
///
/// If your type implement `Into<JsonValue>` call that favor of `serde_json::to:value` to
/// construct the parameters. Because `serde_json::to_value` serializes the type which
/// allocates whereas `Into<JsonValue>` doesn't in most cases.
/// If your type implement `Into<JsonValue>` call that in favor of `serde_json::to:value` to
/// construct the parameters. Because `serde_json::to_value` serializes the type which allocates
/// whereas `Into<JsonValue>` doesn't in most cases.
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum JsonRpcParams<'a> {
Expand Down
16 changes: 10 additions & 6 deletions utils/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use futures_util::{future::BoxFuture, FutureExt};
use jsonrpsee_types::error::{CallError, Error, SubscriptionClosedError};
use jsonrpsee_types::v2::error::{JsonRpcErrorCode, JsonRpcErrorObject, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::v2::params::{
Id, JsonRpcSubscriptionParams, OwnedId, OwnedRpcParams, RpcParams, SubscriptionId as JsonRpcSubscriptionId,
TwoPointZero,
Id, JsonRpcSubscriptionParams, OwnedId, RpcParams, SubscriptionId as JsonRpcSubscriptionId, TwoPointZero,
};
use jsonrpsee_types::v2::request::{JsonRpcNotification, JsonRpcRequest};

Expand All @@ -22,7 +21,9 @@ use std::sync::Arc;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, RpcParams, &MethodSink, ConnectionId) -> Result<(), Error>>;
/// Similar to [`SyncMethod`], but represents an asynchronous handler.
pub type AsyncMethod = Arc<
dyn Send + Sync + Fn(OwnedId, OwnedRpcParams, MethodSink, ConnectionId) -> BoxFuture<'static, Result<(), Error>>,
dyn Send
+ Sync
+ Fn(OwnedId, RpcParams<'static>, MethodSink, ConnectionId) -> BoxFuture<'static, Result<(), Error>>,
>;
/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
Expand Down Expand Up @@ -60,7 +61,7 @@ impl MethodCallback {
MethodCallback::Sync(callback) => (callback)(req.id.clone(), params, tx, conn_id),
MethodCallback::Async(callback) => {
let tx = tx.clone();
let params = OwnedRpcParams::from(params);
let params = params.into_owned();
let id = OwnedId::from(req.id);

(callback)(id, params, tx, conn_id).await
Expand Down Expand Up @@ -215,7 +216,11 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
pub fn register_async_method<R, F>(&mut self, method_name: &'static str, callback: F) -> Result<(), Error>
where
R: Serialize + Send + Sync + 'static,
F: Fn(RpcParams, Arc<Context>) -> BoxFuture<'static, Result<R, CallError>> + Copy + Send + Sync + 'static,
F: Fn(RpcParams<'static>, Arc<Context>) -> BoxFuture<'static, Result<R, CallError>>
+ Copy
+ Send
+ Sync
+ 'static,
{
self.methods.verify_method_name(method_name)?;

Expand All @@ -226,7 +231,6 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
MethodCallback::Async(Arc::new(move |id, params, tx, _| {
let ctx = ctx.clone();
let future = async move {
let params = params.borrowed();
let id = id.borrowed();
match callback(params, ctx).await {
Ok(res) => send_response(id, &tx, res),
Expand Down
38 changes: 27 additions & 11 deletions ws-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ use futures_util::FutureExt;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::{
error::{CallError, Error},
v2::params::RpcParams,
};
use jsonrpsee_types::error::{CallError, Error};
use serde_json::Value as JsonValue;
use std::fmt;
use std::net::SocketAddr;
Expand All @@ -25,13 +22,15 @@ impl fmt::Display for MyAppError {
impl std::error::Error for MyAppError {}

/// Spawns a dummy `JSONRPC v2 WebSocket`
/// It has two hardcoded methods: "say_hello" and "add"
async fn server() -> SocketAddr {
server_with_handles().await.0
}

/// Spawns a dummy `JSONRPC v2 WebSocket`
/// It has two hardcoded methods: "say_hello" and "add"
/// It has the following methods:
/// sync methods: `say_hello` and `add`
/// async: `say_hello_async` and `add_sync`
/// other: `invalid_params` (always returns `CallError::InvalidParams`), `call_fail` (always returns `CallError::Failed`), `sleep_for`
/// Returns the address together with handles for server future and server stop.
async fn server_with_handles() -> (SocketAddr, JoinHandle<()>, StopHandle) {
let mut server = WsServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap();
Expand All @@ -43,7 +42,14 @@ async fn server_with_handles() -> (SocketAddr, JoinHandle<()>, StopHandle) {
})
.unwrap();
module
.register_async_method("say_hello_async", |_: RpcParams, _| {
.register_method("add", |params, _| {
let params: Vec<u64> = params.parse()?;
let sum: u64 = params.into_iter().sum();
Ok(sum)
})
.unwrap();
module
.register_async_method("say_hello_async", |_, _| {
async move {
log::debug!("server respond to hello");
// Call some async function inside.
Expand All @@ -53,13 +59,13 @@ async fn server_with_handles() -> (SocketAddr, JoinHandle<()>, StopHandle) {
.boxed()
})
.unwrap();
module
.register_method("add", |params, _| {
module.register_async_method("add_async", |params, _| {
async move {
let params: Vec<u64> = params.parse()?;
let sum: u64 = params.into_iter().sum();
Ok(sum)
})
.unwrap();
}.boxed()
}).unwrap();
module.register_method("invalid_params", |_params, _| Err::<(), _>(CallError::InvalidParams)).unwrap();
module.register_method("call_fail", |_params, _| Err::<(), _>(CallError::Failed(Box::new(MyAppError)))).unwrap();
module
Expand Down Expand Up @@ -310,6 +316,16 @@ async fn async_method_call_with_ok_context() {
assert_eq!(response, ok_response("ok!".into(), Id::Num(1)));
}

#[tokio::test]
async fn async_method_call_with_params() {
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).await.unwrap();

let req = r#"{"jsonrpc":"2.0","method":"add_async", "params":[1, 2],"id":1}"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, ok_response(JsonValue::Number(3.into()), Id::Num(1)));
}

#[tokio::test]
async fn async_method_call_that_fails() {
let addr = server_with_context().await;
Expand Down