Skip to content

Commit

Permalink
rpc module: fix race in subscription close callback (#1098)
Browse files Browse the repository at this point in the history
* rpc module: fix race in subscription close callback

* use futures::future::try_join

* fix bad test
  • Loading branch information
niklasad1 authored Apr 26, 2023
1 parent 1339e72 commit 8fee8c2
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 33 deletions.
29 changes: 17 additions & 12 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use std::collections::hash_map::Entry;
use std::fmt::{self, Debug};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use crate::error::Error;
Expand Down Expand Up @@ -691,7 +690,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

// response to the subscription call.
let (tx, rx) = oneshot::channel();
let is_accepted = Arc::new(AtomicBool::new(false));
let (accepted_tx, accepted_rx) = oneshot::channel();

let sub_id = uniq_sub.sub_id.clone();
let method = notif_method_name;
Expand All @@ -712,18 +711,24 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
// This runs until the subscription callback has completed.
let sub_fut = callback(params.into_owned(), sink, ctx.clone());

let is_accepted2 = is_accepted.clone();
tokio::spawn(async move {
match sub_fut.await.into_response() {
SubscriptionCloseResponse::Notif(msg) if is_accepted2.load(Ordering::SeqCst) => {
// This will wait for the subscription future to be resolved
let response = match futures_util::future::try_join(sub_fut.map(|f| Ok(f)), accepted_rx).await {
Ok((r, _)) => r.into_response(),
// The accept call failed i.e, the subscription was not accepted.
Err(_) => return,
};

match response {
SubscriptionCloseResponse::Notif(msg) => {
let json = sub_message_to_json(msg, SubNotifResultOrError::Result, &sub_id, method);
_ = method_sink.send(json).await;
let _ = method_sink.send(json).await;
}
SubscriptionCloseResponse::NotifErr(msg) if is_accepted2.load(Ordering::SeqCst) => {
SubscriptionCloseResponse::NotifErr(msg) => {
let json = sub_message_to_json(msg, SubNotifResultOrError::Error, &sub_id, method);
_ = method_sink.send(json).await;
let _ = method_sink.send(json).await;
}
_ => (),
SubscriptionCloseResponse::None => (),
}
});

Expand All @@ -732,10 +737,10 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
Box::pin(async move {
match rx.await {
Ok(msg) => {
// If the subscription is rejected `success` will be set
// to false to prevent further notifications to be sent.
// If the subscription was accepted then send a message
// to subscription task otherwise rely on the drop impl.
if msg.success {
is_accepted.store(true, Ordering::SeqCst);
let _ = accepted_tx.send(());
}
Ok(msg)
}
Expand Down
23 changes: 3 additions & 20 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use jsonrpsee::server::{
AllowHosts, PendingSubscriptionSink, RpcModule, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError,
};
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned};
use jsonrpsee::{IntoSubscriptionCloseResponse, SubscriptionCloseResponse};
use jsonrpsee::SubscriptionCloseResponse;
use serde::Serialize;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;
Expand Down Expand Up @@ -105,25 +105,8 @@ pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle)

module
.register_subscription("subscribe_option", "n", "unsubscribe_option", |_, pending, _| async move {
enum Response {
Nothing,
Closed,
}

impl IntoSubscriptionCloseResponse for Response {
fn into_response(self) -> jsonrpsee::SubscriptionCloseResponse {
match self {
Response::Nothing => SubscriptionCloseResponse::None,
Response::Closed => SubscriptionCloseResponse::Notif("close".into()),
}
}
}

let Ok(_sink) = pending.accept().await else {
return Response::Nothing;
};

Response::Closed
let _ = pending.accept().await;
SubscriptionCloseResponse::None
})
.unwrap();

Expand Down
51 changes: 50 additions & 1 deletion tests/tests/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ async fn rejected_subscription_without_server() {
module
.register_subscription("my_sub", "my_sub", "my_unsub", |_, pending, _| async move {
let err = ErrorObject::borrowed(PARSE_ERROR_CODE, &"rejected", None);
let _ = pending.reject(err.into_owned()).await;
pending.reject(err.into_owned()).await;
Ok(())
})
.unwrap();
Expand Down Expand Up @@ -534,3 +534,52 @@ async fn serialize_sub_error_adds_extra_string_quotes() {
sub_resp
);
}

#[tokio::test]
async fn subscription_close_response_works() {
use jsonrpsee::SubscriptionCloseResponse;

init_logger();

let mut module = RpcModule::new(());

module
.register_subscription("my_sub", "my_sub", "my_unsub", |params, pending, _| async move {
let x = match params.one::<usize>() {
Ok(op) => op,
Err(e) => {
pending.reject(e).await;
return SubscriptionCloseResponse::None;
}
};

let _sink = pending.accept().await.unwrap();

SubscriptionCloseResponse::Notif(SubscriptionMessage::from_json(&x).unwrap())
})
.unwrap();

// ensure subscription with raw_json_request works.
{
let (rp, mut stream) =
module.raw_json_request(r#"{"jsonrpc":"2.0","method":"my_sub","params":[1],"id":0}"#, 1).await.unwrap();
let resp = serde_json::from_str::<Response<u64>>(&rp.result).unwrap();

let sub_id = match resp.payload {
ResponsePayload::Result(val) => val,
_ => panic!("Expected valid response"),
};

assert_eq!(
format!(r#"{{"jsonrpc":"2.0","method":"my_sub","params":{{"subscription":{},"result":1}}}}"#, sub_id),
stream.recv().await.unwrap()
);
}

// ensure subscribe API works.
{
let mut sub = module.subscribe_unbounded("my_sub", [1]).await.unwrap();
let (rx, _id) = sub.next::<usize>().await.unwrap().unwrap();
assert_eq!(rx, 1);
}
}

0 comments on commit 8fee8c2

Please sign in to comment.