Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): Support timeouts with "grpc-timeout" header #606

Merged
merged 12 commits into from
Apr 29, 2021
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ jobs:

env:
RUSTFLAGS: "-D warnings"
# run a lot of quickcheck iterations
QUICKCHECK_TESTS: 1000

steps:
- uses: hecrj/setup-rust-action@master
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ bytes = "1.0"

[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
tokio-stream = { version = "0.1.5", features = ["net"] }

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
92 changes: 92 additions & 0 deletions tests/integration_tests/tests/timeout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::{net::SocketAddr, time::Duration};
use tokio::net::TcpListener;
use tonic::{transport::Server, Code, Request, Response, Status};

#[tokio::test]
async fn cancelation_on_timeout() {
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;

let mut client = test_client::TestClient::connect(format!("http://{}", addr))
.await
.unwrap();

let mut req = Request::new(Input {});
req.metadata_mut()
// 500 ms
.insert("grpc-timeout", "500m".parse().unwrap());

let res = client.unary_call(req).await;

let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}

#[tokio::test]
async fn picks_server_timeout_if_thats_sorter() {
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_millis(100)).await;

let mut client = test_client::TestClient::connect(format!("http://{}", addr))
.await
.unwrap();

let mut req = Request::new(Input {});
req.metadata_mut()
// 10 hours
.insert("grpc-timeout", "10H".parse().unwrap());

let res = client.unary_call(req).await;
let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}

#[tokio::test]
async fn picks_client_timeout_if_thats_sorter() {
let addr = run_service_in_background(Duration::from_secs(1), Duration::from_secs(100)).await;

let mut client = test_client::TestClient::connect(format!("http://{}", addr))
.await
.unwrap();

let mut req = Request::new(Input {});
req.metadata_mut()
// 100 ms
.insert("grpc-timeout", "100m".parse().unwrap());

let res = client.unary_call(req).await;
let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
assert_eq!(err.code(), Code::Cancelled);
}

async fn run_service_in_background(latency: Duration, server_timeout: Duration) -> SocketAddr {
struct Svc {
latency: Duration,
}

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, _req: Request<Input>) -> Result<Response<Output>, Status> {
tokio::time::sleep(self.latency).await;
Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc { latency });

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

tokio::spawn(async move {
Server::builder()
.timeout(server_timeout)
.add_service(svc)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
.await
.unwrap();
});

addr
}
7 changes: 5 additions & 2 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ transport = [
"tokio",
"tower",
"tracing-futures",
"tokio/macros"
"tokio/macros",
"tokio/time",
]
tls = ["transport", "tokio-rustls"]
tls-roots = ["tls", "rustls-native-certs"]
Expand Down Expand Up @@ -68,7 +69,7 @@ h2 = { version = "0.3", optional = true }
hyper = { version = "0.14.2", features = ["full"], optional = true }
tokio = { version = "1.0.1", features = ["net"], optional = true }
tokio-stream = "0.1"
tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tracing-futures = { version = "0.2", optional = true }

# rustls
Expand All @@ -80,6 +81,8 @@ tokio = { version = "1.0", features = ["rt", "macros"] }
static_assertions = "1.0"
rand = "0.8"
bencher = "0.1.5"
quickcheck = "1.0"
quickcheck_macros = "1.0"

[package.metadata.docs.rs]
all-features = true
Expand Down
6 changes: 4 additions & 2 deletions tonic/src/metadata/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,17 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> {
phantom: PhantomData<VE>,
}

#[cfg(feature = "transport")]
pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout";

// ===== impl MetadataMap =====

impl MetadataMap {
// Headers reserved by the gRPC protocol.
pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 8] = [
pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 7] = [
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
"te",
"user-agent",
"content-type",
"grpc-timeout",
"grpc-message",
"grpc-encoding",
"grpc-message-type",
Expand Down
3 changes: 3 additions & 0 deletions tonic/src/metadata/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub use self::value::AsciiMetadataValue;
pub use self::value::BinaryMetadataValue;
pub use self::value::MetadataValue;

#[cfg(feature = "transport")]
pub(crate) use self::map::GRPC_TIMEOUT_HEADER;
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved

/// The metadata::errors module contains types for errors that can occur
/// while handling gRPC custom metadata.
pub mod errors {
Expand Down
6 changes: 5 additions & 1 deletion tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ impl Status {
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
}

fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
let mut cause = Some(err);

while let Some(err) = cause {
Expand All @@ -331,6 +331,10 @@ impl Status {
if let Some(h2) = err.downcast_ref::<h2::Error>() {
return Some(Status::from_h2_error(h2));
}

if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
return Some(Status::cancelled(timeout.to_string()));
}
}

cause = err.source();
Expand Down
2 changes: 2 additions & 0 deletions tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ pub use self::channel::{Channel, Endpoint};
pub use self::error::Error;
#[doc(inline)]
pub use self::server::{NamedService, Server};
#[doc(inline)]
pub use self::service::TimeoutExpired;
pub use self::tls::{Certificate, Identity};
pub use hyper::{Body, Uri};

Expand Down
12 changes: 6 additions & 6 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

mod conn;
mod incoming;
mod recover_error;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
Expand All @@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream;
#[cfg(feature = "tls")]
use crate::transport::Error;

use self::recover_error::RecoverError;
use super::{
service::{Or, Routes, ServerIo},
service::{GrpcTimeout, Or, Routes, ServerIo},
BoxFuture,
};
use crate::{body::BoxBody, request::ConnectionInfo};
Expand All @@ -42,10 +44,7 @@ use std::{
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{
limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service,
ServiceBuilder,
};
use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder};
use tracing_futures::{Instrument, Instrumented};

type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
Expand Down Expand Up @@ -655,8 +654,9 @@ where

Box::pin(async move {
let svc = ServiceBuilder::new()
.layer_fn(RecoverError::new)
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.option_layer(timeout.map(TimeoutLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, timeout))
.service(svc);

let svc = BoxService::new(Svc {
Expand Down
75 changes: 75 additions & 0 deletions tonic/src/transport/server/recover_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use crate::{body::BoxBody, Status};
use futures_util::ready;
use http::Response;
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;

/// Middleware that attempts to recover from service errors by turning them into a response built
/// from the `Status`.
#[derive(Debug, Clone)]
pub(crate) struct RecoverError<S> {
inner: S,
}

impl<S> RecoverError<S> {
pub(crate) fn new(inner: S) -> Self {
Self { inner }
}
}

impl<S, R> Service<R> for RecoverError<S>
where
S: Service<R, Response = Response<BoxBody>>,
S::Error: Into<crate::Error>,
{
type Response = Response<BoxBody>;
type Error = crate::Error;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, req: R) -> Self::Future {
ResponseFuture {
inner: self.inner.call(req),
}
}
}

#[pin_project]
pub(crate) struct ResponseFuture<F> {
#[pin]
inner: F,
}

impl<F, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<BoxBody>, E>>,
E: Into<crate::Error>,
{
type Output = Result<Response<BoxBody>, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let result: Result<Response<BoxBody>, crate::Error> =
ready!(self.project().inner.poll(cx)).map_err(Into::into);

match result {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => {
if let Some(status) = Status::try_from_error(&*err) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to document this actually since its quite unclear. What errors would be returned as a status and which ones would fail? I do think this is the right approach though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanmonstar what do you think here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any error that Status::try_from_error can downcast will be converted to a response. Do you think we should mention that here or refer to the docs for Status::try_from_error (and write some docs for that method as well maybe)?

let mut res = Response::new(BoxBody::empty());
status.add_header(res.headers_mut()).unwrap();
Poll::Ready(Ok(res))
} else {
Poll::Ready(Err(err))
}
}
}
}
}
5 changes: 2 additions & 3 deletions tonic/src/transport/service/connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::super::BoxFuture;
use super::{reconnect::Reconnect, AddOrigin, UserAgent};
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
use crate::{body::BoxBody, transport::Endpoint};
use http::Uri;
use hyper::client::conn::Builder;
Expand All @@ -14,7 +14,6 @@ use tower::load::Load;
use tower::{
layer::Layer,
limit::{concurrency::ConcurrencyLimitLayer, rate::RateLimitLayer},
timeout::TimeoutLayer,
util::BoxService,
ServiceBuilder, ServiceExt,
};
Expand Down Expand Up @@ -53,7 +52,7 @@ impl Connection {
let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.option_layer(endpoint.timeout.map(TimeoutLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
.into_inner();
Expand Down
Loading