Skip to content

Commit

Permalink
feat: inject connection id in extensions (#1381)
Browse files Browse the repository at this point in the history
* feat: inject connection id in extensions

* simplify connection details example
  • Loading branch information
niklasad1 authored May 29, 2024
1 parent 318a6c9 commit b6d94c7
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 127 deletions.
28 changes: 21 additions & 7 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,21 @@ pub type SubscriptionMethod<'a> =
type UnsubscriptionMethod =
Arc<dyn Send + Sync + Fn(Id, Params, ConnectionId, MaxResponseSize, Extensions) -> MethodResponse>;

/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
pub type ConnectionId = usize;
/// Connection ID.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, serde::Deserialize, serde::Serialize)]
pub struct ConnectionId(pub usize);

impl From<u32> for ConnectionId {
fn from(id: u32) -> Self {
Self(id as usize)
}
}

impl From<usize> for ConnectionId {
fn from(id: usize) -> Self {
Self(id)
}
}

/// Max response size.
pub type MaxResponseSize = usize;
Expand Down Expand Up @@ -356,16 +368,18 @@ impl Methods {
let (tx, mut rx) = mpsc::channel(buf_size);
let Request { id, method, params, extensions, .. } = req;
let params = Params::new(params.as_ref().map(|params| params.as_ref().get()));
let max_response_size = usize::MAX;
let conn_id = ConnectionId(0);

let response = match self.method(&method) {
None => MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)),
Some(MethodCallback::Sync(cb)) => (cb)(id, params, usize::MAX, extensions),
Some(MethodCallback::Sync(cb)) => (cb)(id, params, max_response_size, extensions),
Some(MethodCallback::Async(cb)) => {
(cb)(id.into_owned(), params.into_owned(), 0, usize::MAX, extensions).await
(cb)(id.into_owned(), params.into_owned(), conn_id, max_response_size, extensions).await
}
Some(MethodCallback::Subscription(cb)) => {
let conn_state =
SubscriptionState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit };
SubscriptionState { conn_id, id_provider: &RandomIntegerIdProvider, subscription_permit };
let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state, extensions).await;

// This message is not used because it's used for metrics so we discard in other to
Expand All @@ -376,7 +390,7 @@ impl Methods {

res
}
Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, 0, usize::MAX, extensions),
Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, conn_id, max_response_size, extensions),
};

let is_success = response.is_success();
Expand Down
8 changes: 4 additions & 4 deletions core/src/server/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ impl PendingSubscriptionSink {
}

/// Returns connection identifier, which was used to perform pending subscription request
pub fn connection_id(&self) -> ConnectionId {
self.uniq_sub.conn_id
pub fn connection_id(&self) -> usize {
self.uniq_sub.conn_id.0
}
}

Expand Down Expand Up @@ -336,8 +336,8 @@ impl SubscriptionSink {
}

/// Get the connection ID.
pub fn connection_id(&self) -> ConnectionId {
self.uniq_sub.conn_id
pub fn connection_id(&self) -> usize {
self.uniq_sub.conn_id.0
}

/// Send out a response on the subscription and wait until there is capacity.
Expand Down
130 changes: 21 additions & 109 deletions examples/examples/server_with_connection_details.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,62 +25,43 @@
// DEALINGS IN THE SOFTWARE.

use std::net::SocketAddr;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;

use futures::future::{self, Either};
use hyper_util::rt::{TokioExecutor, TokioIo};
use jsonrpsee::core::async_trait;
use jsonrpsee::core::SubscriptionResult;
use jsonrpsee::proc_macros::rpc;
use jsonrpsee::server::middleware::rpc::RpcServiceT;
use jsonrpsee::server::{stop_channel, PendingSubscriptionSink, RpcServiceBuilder, SubscriptionMessage};
use jsonrpsee::server::{PendingSubscriptionSink, SubscriptionMessage};
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned};
use jsonrpsee::ws_client::WsClientBuilder;
use jsonrpsee::ConnectionId;
use jsonrpsee::Extensions;
use tokio::net::TcpListener;
use tower::Service;

#[derive(Debug, Clone)]
struct ConnectionDetails<S> {
inner: S,
connection_id: u32,
}

impl<'a, S> RpcServiceT<'a> for ConnectionDetails<S>
where
S: RpcServiceT<'a>,
{
type Future = S::Future;

fn call(&self, mut request: jsonrpsee::types::Request<'a>) -> Self::Future {
request.extensions_mut().insert(self.connection_id);
self.inner.call(request)
}
}

#[rpc(server, client)]
pub trait Rpc {
/// method with connection ID.
#[method(name = "connectionIdMethod")]
async fn method(&self, first_param: usize, second_param: u16) -> Result<u32, ErrorObjectOwned>;
async fn method(&self) -> Result<usize, ErrorObjectOwned>;

#[subscription(name = "subscribeConnectionId", item = u32)]
#[subscription(name = "subscribeConnectionId", item = usize)]
async fn sub(&self) -> SubscriptionResult;
}

pub struct RpcServerImpl;

#[async_trait]
impl RpcServer for RpcServerImpl {
async fn method(&self, ext: &Extensions, _first_param: usize, _second_param: u16) -> Result<u32, ErrorObjectOwned> {
ext.get::<u32>().cloned().ok_or_else(|| ErrorObject::owned(0, "No connection details found", None::<()>))
async fn method(&self, ext: &Extensions) -> Result<usize, ErrorObjectOwned> {
let conn_id = ext
.get::<ConnectionId>()
.cloned()
.ok_or_else(|| ErrorObject::owned(0, "No connection details found", None::<()>))?;

Ok(conn_id.0)
}

async fn sub(&self, pending: PendingSubscriptionSink, ext: &Extensions) -> SubscriptionResult {
let sink = pending.accept().await?;
let conn_id = ext
.get::<u32>()
.get::<ConnectionId>()
.cloned()
.ok_or_else(|| ErrorObject::owned(0, "No connection details found", None::<()>))?;
sink.send(SubscriptionMessage::from_json(&conn_id).unwrap()).await?;
Expand All @@ -99,14 +80,14 @@ async fn main() -> anyhow::Result<()> {
let url = format!("ws://{}", server_addr);

let client = WsClientBuilder::default().build(&url).await?;
let connection_id_first = client.method(1, 2).await.unwrap();
let connection_id_first = client.method().await.unwrap();

// Second call from the same connection ID.
assert_eq!(client.method(1, 2).await.unwrap(), connection_id_first);
assert_eq!(client.method().await.unwrap(), connection_id_first);

// Second client will increment the connection ID.
let client2 = WsClientBuilder::default().build(&url).await?;
let connection_id_second = client2.method(1, 2).await.unwrap();
let connection_id_second = client2.method().await.unwrap();
assert_ne!(connection_id_first, connection_id_second);

let mut sub = client.sub().await.unwrap();
Expand All @@ -119,81 +100,12 @@ async fn main() -> anyhow::Result<()> {
}

async fn run_server() -> anyhow::Result<SocketAddr> {
let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?;
let addr = listener.local_addr()?;

let (stop_hdl, server_hdl) = stop_channel();

tokio::spawn(async move {
let conn_id = Arc::new(AtomicU32::new(0));
// Create and finalize a server configuration from a TowerServiceBuilder
// given an RpcModule and the stop handle.
let svc_builder = jsonrpsee::server::Server::builder().to_service_builder();
let methods = RpcServerImpl.into_rpc();

loop {
let stream = tokio::select! {
res = listener.accept() => {
match res {
Ok((stream, _remote_addr)) => stream,
Err(e) => {
tracing::error!("failed to accept v4 connection: {:?}", e);
continue;
}
}
}
_ = stop_hdl.clone().shutdown() => break,
};

let methods2 = methods.clone();
let stop_hdl2 = stop_hdl.clone();
let svc_builder2 = svc_builder.clone();
let conn_id2 = conn_id.clone();
let svc = hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
let connection_id = conn_id2.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let rpc_middleware = RpcServiceBuilder::default()
.layer_fn(move |service| ConnectionDetails { inner: service, connection_id });

// Start a new service with our own connection ID.
let mut tower_service = svc_builder2
.clone()
.set_rpc_middleware(rpc_middleware)
.connection_id(connection_id)
.build(methods2.clone(), stop_hdl2.clone());

async move { tower_service.call(req).await.map_err(|e| anyhow::anyhow!("{:?}", e)) }
});

let stop_hdl2 = stop_hdl.clone();
// Spawn a new task to serve each respective (Hyper) connection.
tokio::spawn(async move {
let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
let conn = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc);
let stopped = stop_hdl2.shutdown();

// Pin the future so that it can be polled.
tokio::pin!(stopped, conn);

let res = match future::select(conn, stopped).await {
// Return the connection if not stopped.
Either::Left((conn, _)) => conn,
// If the server is stopped, we should gracefully shutdown
// the connection and poll it until it finishes.
Either::Right((_, mut conn)) => {
conn.as_mut().graceful_shutdown();
conn.await
}
};

// Log any errors that might have occurred.
if let Err(err) = res {
tracing::error!(err=?err, "HTTP connection failed");
}
});
}
});

tokio::spawn(server_hdl.stopped());
let server = jsonrpsee::server::Server::builder().build("127.0.0.1:0").await?;
let addr = server.local_addr()?;

let handle = server.start(RpcServerImpl.into_rpc());

tokio::spawn(handle.stopped());

Ok(addr)
}
14 changes: 11 additions & 3 deletions server/src/middleware/rpc/layer/rpc_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use super::ResponseFuture;
use std::sync::Arc;

use crate::middleware::rpc::RpcServiceT;
use crate::ConnectionId;
use futures_util::future::BoxFuture;
use jsonrpsee_core::server::{
BoundedSubscriptions, MethodCallback, MethodResponse, MethodSink, Methods, SubscriptionState,
Expand All @@ -41,7 +42,7 @@ use jsonrpsee_types::{ErrorObject, Request};
/// JSON-RPC service middleware.
#[derive(Clone, Debug)]
pub struct RpcService {
conn_id: usize,
conn_id: ConnectionId,
methods: Methods,
max_response_body_size: usize,
cfg: RpcServiceCfg,
Expand All @@ -63,7 +64,12 @@ pub(crate) enum RpcServiceCfg {

impl RpcService {
/// Create a new service.
pub(crate) fn new(methods: Methods, max_response_body_size: usize, conn_id: usize, cfg: RpcServiceCfg) -> Self {
pub(crate) fn new(
methods: Methods,
max_response_body_size: usize,
conn_id: ConnectionId,
cfg: RpcServiceCfg,
) -> Self {
Self { methods, max_response_body_size, conn_id, cfg }
}
}
Expand All @@ -77,7 +83,9 @@ impl<'a> RpcServiceT<'a> for RpcService {
let conn_id = self.conn_id;
let max_response_body_size = self.max_response_body_size;

let Request { id, method, params, extensions, .. } = req;
let Request { id, method, params, mut extensions, .. } = req;
extensions.insert(conn_id);

let params = jsonrpsee_types::Params::new(params.as_ref().map(|p| serde_json::value::RawValue::get(p)));

match self.methods.method_with_name(&method) {
Expand Down
4 changes: 2 additions & 2 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ where
let rpc_service = RpcService::new(
this.methods.clone(),
this.server_cfg.max_response_body_size as usize,
this.conn_id as usize,
this.conn_id.into(),
cfg,
);

Expand Down Expand Up @@ -1160,7 +1160,7 @@ where
let rpc_service = self.rpc_middleware.service(RpcService::new(
methods,
max_response_size as usize,
this.conn_id as usize,
this.conn_id.into(),
RpcServiceCfg::OnlyCalls,
));

Expand Down
2 changes: 1 addition & 1 deletion server/src/transport/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ where
let rpc_service = rpc_service.service(RpcService::new(
methods.into(),
max_response_body_size as usize,
conn.conn_id as usize,
conn.conn_id.into(),
RpcServiceCfg::OnlyCalls,
));

Expand Down
2 changes: 1 addition & 1 deletion server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ where
let rpc_service = RpcService::new(
methods.into(),
server_cfg.max_response_body_size as usize,
conn.conn_id as usize,
conn.conn_id.into(),
rpc_service_cfg,
);

Expand Down

0 comments on commit b6d94c7

Please sign in to comment.