diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index df6bc4b3b..99e20c695 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -194,3 +194,52 @@ async fn status_from_server_stream_with_source() { let source = error.source().unwrap(); source.downcast_ref::().unwrap(); } + +#[tokio::test] +async fn message_and_then_status_from_server_stream() { + integration_tests::trace_init(); + + struct Svc; + + #[tonic::async_trait] + impl test_stream_server::TestStream for Svc { + type StreamCallStream = Stream; + + async fn stream_call( + &self, + _: Request, + ) -> Result, Status> { + let s = tokio_stream::iter(vec![ + Ok(OutputStream {}), + Err::(Status::unavailable("foo")), + ]); + Ok(Response::new(Box::pin(s) as Self::StreamCallStream)) + } + } + + let svc = test_stream_server::TestStreamServer::new(Svc); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve("127.0.0.1:1340".parse().unwrap()) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut client = test_stream_client::TestStreamClient::connect("http://127.0.0.1:1340") + .await + .unwrap(); + + let mut stream = client + .stream_call(InputStream {}) + .await + .unwrap() + .into_inner(); + + assert_eq!(stream.message().await.unwrap(), Some(OutputStream {})); + assert_eq!(stream.message().await.unwrap_err().message(), "foo"); + assert_eq!(stream.message().await.unwrap(), None); +} diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 82b4eb61d..0b5de1bda 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -74,6 +74,7 @@ where max_message_size: Option, buf: BytesMut, uncompression_buf: BytesMut, + error: Option, } impl EncodedBytes @@ -112,6 +113,7 @@ where max_message_size, buf, uncompression_buf, + error: None, } } } @@ -131,9 +133,14 @@ where max_message_size, buf, uncompression_buf, + error, } = self.project(); let buffer_settings = encoder.buffer_settings(); + if let Some(status) = error.take() { + return Poll::Ready(Some(Err(status))); + } + loop { match source.as_mut().poll_next(cx) { Poll::Pending if buf.is_empty() => { @@ -163,7 +170,11 @@ where } } Poll::Ready(Some(Err(status))) => { - return Poll::Ready(Some(Err(status))); + if buf.is_empty() { + return Poll::Ready(Some(Err(status))); + } + *error = Some(status); + return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); } } }