diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs index 2ceca4abd..7a4b94e6b 100644 --- a/tests/integration_tests/tests/timeout.rs +++ b/tests/integration_tests/tests/timeout.rs @@ -45,7 +45,7 @@ async fn cancelation_on_timeout() { } #[tokio::test] -async fn picks_the_shortest_timeout() { +async fn picks_server_timeout_if_thats_sorter() { struct Svc; #[tonic::async_trait] @@ -80,10 +80,50 @@ async fn picks_the_shortest_timeout() { // 10 hours .insert("grpc-timeout", "10H".parse().unwrap()); - // TODO(david): for some reason this fails with "h2 protocol error: protocol error: unexpected - // internal error encountered". Seems to be happening on `master` as well. Bug? let res = client.unary_call(req).await; - dbg!(&res); let err = res.unwrap_err(); assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); +} + +#[tokio::test] +async fn picks_client_timeout_if_thats_sorter() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + // Wait for a time longer than the timeout + tokio::time::sleep(Duration::from_secs(1)).await; + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .timeout(Duration::from_secs(9001)) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let mut client = test_client::TestClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let mut req = Request::new(Input {}); + req.metadata_mut() + // 100 ms + .insert("grpc-timeout", "100m".parse().unwrap()); + + let res = client.unary_call(req).await; + let err = res.unwrap_err(); + assert!(err.message().contains("Timeout expired")); + assert_eq!(err.code(), Code::Cancelled); } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 0d417e1e6..5e494d5c9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -69,7 +69,7 @@ h2 = { version = "0.3", optional = true } hyper = { version = "0.14.2", features = ["full"], optional = true } tokio = { version = "1.0.1", features = ["net"], optional = true } tokio-stream = "0.1" -tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } +tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } tracing-futures = { version = "0.2", optional = true } # rustls diff --git a/tonic/src/status.rs b/tonic/src/status.rs index f1d735c9d..071e1f9cd 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -313,7 +313,7 @@ impl Status { Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string())) } - fn try_from_error(err: &(dyn Error + 'static)) -> Option { + pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option { let mut cause = Some(err); while let Some(err) = cause { diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 6f2729d09..c2ad4d5ef 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -2,6 +2,7 @@ mod conn; mod incoming; +mod recover_error; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; @@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream; #[cfg(feature = "tls")] use crate::transport::Error; +use self::recover_error::RecoverError; use super::{ - service::{Or, Routes, ServerIo}, + service::{GrpcTimeout, Or, Routes, ServerIo}, BoxFuture, }; use crate::{body::BoxBody, request::ConnectionInfo}; @@ -42,10 +44,7 @@ use std::{ time::Duration, }; use tokio::io::{AsyncRead, AsyncWrite}; -use tower::{ - limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service, - ServiceBuilder, -}; +use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder}; use tracing_futures::{Instrument, Instrumented}; type BoxService = tower::util::BoxService, Response, crate::Error>; @@ -655,8 +654,9 @@ where Box::pin(async move { let svc = ServiceBuilder::new() + .layer_fn(RecoverError::new) .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new)) - .option_layer(timeout.map(TimeoutLayer::new)) + .layer_fn(|s| GrpcTimeout::new(s, timeout)) .service(svc); let svc = BoxService::new(Svc { diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs new file mode 100644 index 000000000..4c7fdc666 --- /dev/null +++ b/tonic/src/transport/server/recover_error.rs @@ -0,0 +1,75 @@ +use crate::{body::BoxBody, Status}; +use futures_util::ready; +use http::Response; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::Service; + +/// Middleware that attempts to recover from service errors by turning them into a response with a +/// `grpc-status` and an empty body. +#[derive(Debug, Clone)] +pub(crate) struct RecoverError { + inner: S, +} + +impl RecoverError { + pub(crate) fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service for RecoverError +where + S: Service>, + S::Error: Into, +{ + type Response = Response; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: R) -> Self::Future { + ResponseFuture { + inner: self.inner.call(req), + } + } +} + +#[pin_project] +pub(crate) struct ResponseFuture { + #[pin] + inner: F, +} + +impl Future for ResponseFuture +where + F: Future, E>>, + E: Into, +{ + type Output = Result, crate::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result: Result, crate::Error> = + ready!(self.project().inner.poll(cx)).map_err(Into::into); + + match result { + Ok(res) => Poll::Ready(Ok(res)), + Err(err) => { + if let Some(status) = Status::try_from_error(&*err) { + let mut res = Response::new(BoxBody::empty()); + status.add_header(res.headers_mut()).unwrap(); + Poll::Ready(Ok(res)) + } else { + Poll::Ready(Err(err)) + } + } + } + } +} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index b8386561f..4e1d89c0c 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -14,6 +14,7 @@ pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; pub(crate) use self::connector::connector; pub(crate) use self::discover::DynamicServiceStream; +pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")]