Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add h2::Error as a source for tonic::Status when converting from h2::Error #612

Merged
merged 14 commits into from
Jun 23, 2021
2 changes: 1 addition & 1 deletion tonic/src/client/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl<T> Grpc<T> {
.inner
.call(request)
.await
.map_err(|err| Status::from_error(&*(err.into())))?;
.map_err(|err| Status::from_error(err.into()))?;

let status_code = response.status();
let trailers_only_status = Status::from_header_map(response.headers());
Expand Down
8 changes: 4 additions & 4 deletions tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ impl<T> Streaming<T> {
// them manually.
let map = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx))
.await
.map_err(|e| Status::from_error(&e))?;
.map_err(|e| Status::from_error(Box::new(e)));

Ok(map.map(MetadataMap::from_headers))
map.map(|x| x.map(MetadataMap::from_headers))
}

fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
Expand Down Expand Up @@ -232,7 +232,7 @@ impl<T> Stream for Streaming<T> {
Some(Err(e)) => {
let err: crate::Error = e.into();
debug!("decoder inner stream error: {:?}", err);
let status = Status::from_error(&*err);
let status = Status::from_error(err);
return Poll::Ready(Some(Err(status)));
}
None => None,
Expand Down Expand Up @@ -266,7 +266,7 @@ impl<T> Stream for Streaming<T> {
Err(e) => {
let err: crate::Error = e.into();
debug!("decoder inner trailers error: {:?}", err);
let status = Status::from_error(&*err);
let status = Status::from_error(err);
return Some(Err(status)).into();
}
}
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/codec/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ mod tests {

let msg = Vec::from(&[0u8; 1024][..]);

let messages = std::iter::repeat(Ok::<_, Status>(msg)).take(10000);
let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
let source = futures_util::stream::iter(messages);

let body = encode_server(encoder, source);
Expand Down
118 changes: 83 additions & 35 deletions tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ const GRPC_STATUS_DETAILS_HEADER: &str = "grpc-status-details-bin";
/// assert_eq!(status1.code(), Code::InvalidArgument);
/// assert_eq!(status1.code(), status2.code());
/// ```
#[derive(Clone)]
pub struct Status {
/// The gRPC status code, found in the `grpc-status` header.
code: Code,
Expand All @@ -45,6 +44,8 @@ pub struct Status {
/// If the metadata contains any headers with names reserved either by the gRPC spec
/// or by `Status` fields above, they will be ignored.
metadata: MetadataMap,
/// Optional underlying error.
source: Option<Box<dyn Error + Send + Sync + 'static>>,
}

/// gRPC status codes used by [`Status`].
Expand Down Expand Up @@ -162,6 +163,7 @@ impl Status {
message: message.into(),
details: Bytes::new(),
metadata: MetadataMap::new(),
source: None,
}
}

Expand Down Expand Up @@ -302,38 +304,34 @@ impl Status {
}

#[cfg_attr(not(feature = "transport"), allow(dead_code))]
pub(crate) fn from_error(err: &(dyn Error + 'static)) -> Status {
Status::try_from_error(err).unwrap_or_else(|| Status::new(Code::Unknown, err.to_string()))
pub(crate) fn from_error(err: Box<dyn Error + Send + Sync + 'static>) -> Status {
Status::try_from_error(err)
.unwrap_or_else(|err| Status::new(Code::Unknown, err.to_string()))
}

pub(crate) fn try_from_error(err: &(dyn Error + 'static)) -> Option<Status> {
let mut cause = Some(err);

while let Some(err) = cause {
if let Some(status) = err.downcast_ref::<Status>() {
return Some(Status {
code: status.code,
message: status.message.clone(),
details: status.details.clone(),
metadata: status.metadata.clone(),
});
pub(crate) fn try_from_error(
err: Box<dyn Error + Send + Sync + 'static>,
) -> Result<Status, Box<dyn Error + Send + Sync + 'static>> {
let err = match err.downcast::<Status>() {
Ok(status) => {
return Ok(*status);
}
Err(err) => err,
};

#[cfg(feature = "transport")]
{
if let Some(h2) = err.downcast_ref::<h2::Error>() {
return Some(Status::from_h2_error(h2));
}

if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
return Some(Status::cancelled(timeout.to_string()));
}
#[cfg(feature = "transport")]
let err = match err.downcast::<h2::Error>() {
Ok(h2) => {
return Ok(Status::from_h2_error(&*h2));
}
Err(err) => err,
};

cause = err.source();
if let Some(status) = find_status_in_source_chain(&*err) {
return Ok(status);
}

None
Err(err)
}

// FIXME: bubble this into `transport` and expose generic http2 reasons.
Expand All @@ -356,7 +354,13 @@ impl Status {
_ => Code::Unknown,
};

Status::new(code, format!("h2 protocol error: {}", err))
let mut status = Self::new(code, format!("h2 protocol error: {}", err));
let error = err
.reason()
.map(h2::Error::from)
.map(|err| Box::new(err) as Box<dyn Error + Send + Sync + 'static>);
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
status.source = error;
status
}

#[cfg(feature = "transport")]
Expand All @@ -374,7 +378,8 @@ impl Status {
where
E: Into<Box<dyn Error + Send + Sync>>,
{
Status::from_error(&*err.into())
let err: Box<dyn Error + Send + Sync> = err.into();
Status::from_error(err)
}

/// Extract a `Status` from a hyper `HeaderMap`.
Expand Down Expand Up @@ -410,6 +415,7 @@ impl Status {
message,
details,
metadata: MetadataMap::from_headers(other_headers),
source: None,
},
Err(err) => {
warn!("Error deserializing status message header: {}", err);
Expand All @@ -418,6 +424,7 @@ impl Status {
message: format!("Error deserializing status message header: {}", err),
details,
metadata: MetadataMap::from_headers(other_headers),
source: None,
}
}
}
Expand Down Expand Up @@ -505,6 +512,7 @@ impl Status {
message: message.into(),
details,
metadata,
source: None,
}
}

Expand All @@ -524,6 +532,32 @@ impl Status {
}
}

fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
let mut source = Some(err);

while let Some(err) = source {
if let Some(status) = err.downcast_ref::<Status>() {
return Some(Status {
code: status.code,
message: status.message.clone(),
details: status.details.clone(),
metadata: status.metadata.clone(),
// Since `Status` is not `Clone`, any `source` on the original Status
// cannot be cloned so must remain with the original `Status`.
source: None,
});
}

#[cfg(feature = "transport")]
if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
return Some(Status::cancelled(timeout.to_string()));
}

source = err.source();
}

None
}
impl fmt::Debug for Status {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// A manual impl to reduce the noise of frequently empty fields.
Expand All @@ -543,6 +577,8 @@ impl fmt::Debug for Status {
builder.field("metadata", &self.metadata);
}

builder.field("source", &self.source);

builder.finish()
}
}
Expand Down Expand Up @@ -609,7 +645,11 @@ impl fmt::Display for Status {
}
}

impl Error for Status {}
impl Error for Status {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source.as_ref().map(|err| (&**err) as _)
}
}

///
/// Take the `Status` value from `trailers` if it is available, else from `status_code`.
Expand Down Expand Up @@ -775,25 +815,25 @@ mod tests {
#[test]
fn from_error_status() {
let orig = Status::new(Code::OutOfRange, "weeaboo");
let found = Status::from_error(&orig);
let found = Status::from_error(Box::new(orig));

assert_eq!(orig.code(), found.code());
assert_eq!(orig.message(), found.message());
assert_eq!(found.code(), Code::OutOfRange);
assert_eq!(found.message(), "weeaboo");
}

#[test]
fn from_error_unknown() {
let orig: Error = "peek-a-boo".into();
let found = Status::from_error(&*orig);
let found = Status::from_error(orig);

assert_eq!(found.code(), Code::Unknown);
assert_eq!(found.message(), orig.to_string());
assert_eq!(found.message(), "peek-a-boo".to_string());
}

#[test]
fn from_error_nested() {
let orig = Nested(Box::new(Status::new(Code::OutOfRange, "weeaboo")));
let found = Status::from_error(&orig);
let found = Status::from_error(Box::new(orig));

assert_eq!(found.code(), Code::OutOfRange);
assert_eq!(found.message(), "weeaboo");
Expand All @@ -802,10 +842,18 @@ mod tests {
#[test]
#[cfg(feature = "transport")]
fn from_error_h2() {
use std::error::Error as _;

let orig = h2::Error::from(h2::Reason::CANCEL);
let found = Status::from_error(&orig);
let found = Status::from_error(Box::new(orig));

assert_eq!(found.code(), Code::Cancelled);

let source = found
.source()
.and_then(|err| err.downcast_ref::<h2::Error>())
.unwrap();
assert_eq!(source.reason(), Some(h2::Reason::CANCEL));
}

#[test]
Expand Down
9 changes: 4 additions & 5 deletions tonic/src/transport/server/recover_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,14 @@ where
let response = response.map(MaybeEmptyBody::full);
Poll::Ready(Ok(response))
}
Err(err) => {
if let Some(status) = Status::try_from_error(&*err) {
Err(err) => match Status::try_from_error(err) {
Ok(status) => {
let mut res = Response::new(MaybeEmptyBody::empty());
status.add_header(res.headers_mut()).unwrap();
Poll::Ready(Ok(res))
} else {
Poll::Ready(Err(err))
}
}
Err(err) => Poll::Ready(Err(err)),
},
}
}
}
Expand Down