Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tonic): add Request and Response extensions #642

Merged
merged 2 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion examples/src/interceptor/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
let extension = request.extensions().get::<MyExtension>().unwrap();
println!("extension data = {}", extension.some_piece_of_data);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Expand All @@ -40,7 +43,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// This function will get called on each inbound request, if a `Status`
/// is returned, it will cancel the request and return that status to the
/// client.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
fn intercept(mut req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting request: {:?}", req);

// Set an extension that can be retrieved by `say_hello`
req.extensions_mut().insert(MyExtension {
some_piece_of_data: "foo".to_string(),
});

Ok(req)
}

struct MyExtension {
some_piece_of_data: String,
}
3 changes: 3 additions & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ bytes = "1.0"
[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
tokio-stream = { version = "0.1.5", features = ["net"] }
tower-service = "0.3"
hyper = "0.14"
futures = "0.3"

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
144 changes: 144 additions & 0 deletions tests/integration_tests/tests/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
use futures_util::FutureExt;
use hyper::{Body, Request as HyperRequest, Response as HyperResponse};
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::{
task::{Context, Poll},
time::Duration,
};
use tokio::sync::oneshot;
use tonic::{
body::BoxBody,
transport::{Endpoint, NamedService, Server},
Request, Response, Status,
};
use tower_service::Service;

struct ExtensionValue(i32);

#[tokio::test]
async fn setting_extension_from_interceptor() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let value = req.extensions().get::<ExtensionValue>().unwrap();
assert_eq!(value.0, 42);

Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::with_interceptor(Svc, |mut req: Request<()>| {
req.extensions_mut().insert(ExtensionValue(42));
Ok(req)
});

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1323".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1323")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();

jh.await.unwrap();
}

#[tokio::test]
async fn setting_extension_from_tower() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let value = req.extensions().get::<ExtensionValue>().unwrap();
assert_eq!(value.0, 42);

Ok(Response::new(Output {}))
}
}

let svc = InterceptedService {
inner: test_server::TestServer::new(Svc),
};

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1324".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1324")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();

jh.await.unwrap();
}

#[derive(Debug, Clone)]
struct InterceptedService<S> {
inner: S,
}

impl<S> Service<HyperRequest<Body>> for InterceptedService<S>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its crazy how complex this is :(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I really wanna try and find a way to make this easier. Having to implement NamedService also complicates things a bit because you cannot take a middleware from tower and use it with tonics router.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, though I don't have a proper solution either :(

where
S: Service<HyperRequest<Body>, Response = HyperResponse<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: HyperRequest<Body>) -> Self::Future {
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);

req.extensions_mut().insert(ExtensionValue(42));

Box::pin(async move {
let response = inner.call(req).await?;
Ok(response)
})
}
}

impl<S: NamedService> NamedService for InterceptedService<S> {
const NAME: &'static str = S::NAME;
}
5 changes: 3 additions & 2 deletions tonic/src/client/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ impl<T> Grpc<T> {
M1: Send + Sync + 'static,
M2: Send + Sync + 'static,
{
let (mut parts, body) = self.streaming(request, path, codec).await?.into_parts();
let (mut parts, body, extensions) =
self.streaming(request, path, codec).await?.into_parts();

futures_util::pin_mut!(body);

Expand All @@ -114,7 +115,7 @@ impl<T> Grpc<T> {
parts.merge(trailers);
}

Ok(Response::from_parts(parts, message))
Ok(Response::from_parts(parts, message, extensions))
}

/// Send a server side streaming gRPC request.
Expand Down
71 changes: 71 additions & 0 deletions tonic/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::fmt;

/// A type map of protocol extensions.
///
/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from
/// the underlying protocol.
///
/// [`Interceptor`]: crate::Interceptor
/// [`Request`]: crate::Request
pub struct Extensions {
inner: http::Extensions,
}

impl Extensions {
pub(crate) fn new() -> Self {
Self {
inner: http::Extensions::new(),
}
}

/// Insert a type into this `Extensions`.
///
/// If a extension of this type already existed, it will
/// be returned.
#[inline]
pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
self.inner.insert(val)
}

/// Get a reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.inner.get()
}

/// Get a mutable reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
self.inner.get_mut()
}

/// Remove a type from this `Extensions`.
///
/// If a extension of this type existed, it will be returned.
#[inline]
pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
self.inner.remove()
}

/// Clear the `Extensions` of all inserted extensions.
#[inline]
pub fn clear(&mut self) {
self.inner.clear()
}

#[inline]
pub(crate) fn from_http(http: http::Extensions) -> Self {
Self { inner: http }
}

#[inline]
pub(crate) fn into_http(self) -> http::Extensions {
self.inner
}
}

impl fmt::Debug for Extensions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Extensions").finish()
}
}
3 changes: 3 additions & 0 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
//! [`transport`]: transport/index.html

#![recursion_limit = "256"]
#![allow(clippy::inconsistent_struct_constructor)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the heck is this lint lmao

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The most pedantic lint ever https://rust-lang.github.io/rust-clippy/master/#inconsistent_struct_constructor. rust-analyzer was complaining about it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lmaooooo

#![warn(
missing_debug_implementations,
missing_docs,
Expand All @@ -87,6 +88,7 @@ pub mod server;
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
pub mod transport;

mod extensions;
mod interceptor;
mod macros;
mod request;
Expand All @@ -100,6 +102,7 @@ pub use async_trait::async_trait;

#[doc(inline)]
pub use codec::Streaming;
pub use extensions::Extensions;
pub use interceptor::Interceptor;
pub use request::{IntoRequest, IntoStreamingRequest, Request};
pub use response::Response;
Expand Down
Loading