diff --git a/tests/integration_tests/tests/timeout.rs b/tests/integration_tests/tests/timeout.rs
index 2ceca4abd..7a4b94e6b 100644
--- a/tests/integration_tests/tests/timeout.rs
+++ b/tests/integration_tests/tests/timeout.rs
@@ -45,7 +45,7 @@ async fn cancelation_on_timeout() {
}
#[tokio::test]
-async fn picks_the_shortest_timeout() {
+async fn picks_server_timeout_if_thats_sorter() {
struct Svc;
#[tonic::async_trait]
@@ -80,10 +80,50 @@ async fn picks_the_shortest_timeout() {
// 10 hours
.insert("grpc-timeout", "10H".parse().unwrap());
- // TODO(david): for some reason this fails with "h2 protocol error: protocol error: unexpected
- // internal error encountered". Seems to be happening on `master` as well. Bug?
let res = client.unary_call(req).await;
- dbg!(&res);
let err = res.unwrap_err();
assert!(err.message().contains("Timeout expired"));
+ assert_eq!(err.code(), Code::Cancelled);
+}
+
+#[tokio::test]
+async fn picks_client_timeout_if_thats_sorter() {
+ struct Svc;
+
+ #[tonic::async_trait]
+ impl test_server::Test for Svc {
+ async fn unary_call(&self, _req: Request) -> Result, Status> {
+ // Wait for a time longer than the timeout
+ tokio::time::sleep(Duration::from_secs(1)).await;
+ Ok(Response::new(Output {}))
+ }
+ }
+
+ let svc = test_server::TestServer::new(Svc);
+
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let addr = listener.local_addr().unwrap();
+
+ tokio::spawn(async move {
+ Server::builder()
+ .timeout(Duration::from_secs(9001))
+ .add_service(svc)
+ .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
+ .await
+ .unwrap();
+ });
+
+ let mut client = test_client::TestClient::connect(format!("http://{}", addr))
+ .await
+ .unwrap();
+
+ let mut req = Request::new(Input {});
+ req.metadata_mut()
+ // 100 ms
+ .insert("grpc-timeout", "100m".parse().unwrap());
+
+ let res = client.unary_call(req).await;
+ let err = res.unwrap_err();
+ assert!(err.message().contains("Timeout expired"));
+ assert_eq!(err.code(), Code::Cancelled);
}
diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml
index 0d417e1e6..5e494d5c9 100644
--- a/tonic/Cargo.toml
+++ b/tonic/Cargo.toml
@@ -69,7 +69,7 @@ h2 = { version = "0.3", optional = true }
hyper = { version = "0.14.2", features = ["full"], optional = true }
tokio = { version = "1.0.1", features = ["net"], optional = true }
tokio-stream = "0.1"
-tower = { version = "0.4.4", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
+tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true }
tracing-futures = { version = "0.2", optional = true }
# rustls
diff --git a/tonic/src/status.rs b/tonic/src/status.rs
index f1d735c9d..071e1f9cd 100644
--- a/tonic/src/status.rs
+++ b/tonic/src/status.rs
@@ -313,7 +313,7 @@ impl Status {
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
}
- fn try_from_error(err: &(dyn Error + 'static)) -> Option {
+ pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option {
let mut cause = Some(err);
while let Some(err) = cause {
diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs
index 6f2729d09..c2ad4d5ef 100644
--- a/tonic/src/transport/server/mod.rs
+++ b/tonic/src/transport/server/mod.rs
@@ -2,6 +2,7 @@
mod conn;
mod incoming;
+mod recover_error;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
@@ -21,8 +22,9 @@ pub(crate) use tokio_rustls::server::TlsStream;
#[cfg(feature = "tls")]
use crate::transport::Error;
+use self::recover_error::RecoverError;
use super::{
- service::{Or, Routes, ServerIo},
+ service::{GrpcTimeout, Or, Routes, ServerIo},
BoxFuture,
};
use crate::{body::BoxBody, request::ConnectionInfo};
@@ -42,10 +44,7 @@ use std::{
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
-use tower::{
- limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service,
- ServiceBuilder,
-};
+use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder};
use tracing_futures::{Instrument, Instrumented};
type BoxService = tower::util::BoxService, Response, crate::Error>;
@@ -655,8 +654,9 @@ where
Box::pin(async move {
let svc = ServiceBuilder::new()
+ .layer_fn(RecoverError::new)
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
- .option_layer(timeout.map(TimeoutLayer::new))
+ .layer_fn(|s| GrpcTimeout::new(s, timeout))
.service(svc);
let svc = BoxService::new(Svc {
diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs
new file mode 100644
index 000000000..4c7fdc666
--- /dev/null
+++ b/tonic/src/transport/server/recover_error.rs
@@ -0,0 +1,75 @@
+use crate::{body::BoxBody, Status};
+use futures_util::ready;
+use http::Response;
+use pin_project::pin_project;
+use std::{
+ future::Future,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tower::Service;
+
+/// Middleware that attempts to recover from service errors by turning them into a response with a
+/// `grpc-status` and an empty body.
+#[derive(Debug, Clone)]
+pub(crate) struct RecoverError {
+ inner: S,
+}
+
+impl RecoverError {
+ pub(crate) fn new(inner: S) -> Self {
+ Self { inner }
+ }
+}
+
+impl Service for RecoverError
+where
+ S: Service>,
+ S::Error: Into,
+{
+ type Response = Response;
+ type Error = crate::Error;
+ type Future = ResponseFuture;
+
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.inner.poll_ready(cx).map_err(Into::into)
+ }
+
+ fn call(&mut self, req: R) -> Self::Future {
+ ResponseFuture {
+ inner: self.inner.call(req),
+ }
+ }
+}
+
+#[pin_project]
+pub(crate) struct ResponseFuture {
+ #[pin]
+ inner: F,
+}
+
+impl Future for ResponseFuture
+where
+ F: Future