Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ability to customize stream errors #930

Merged
merged 5 commits into from
May 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions docs/_docs/customizingyourgateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,73 @@ if err := http.ListenAndServe(":8080", tracingWrapper(mux)); err != nil {
## Error handler
http://mycodesmells.com/post/grpc-gateway-error-handler

## Stream Error Handler
The error handler described in the previous section applies only
to RPC methods that have a unary response.

When the method has a streaming response, grpc-gateway handles
that by emitting a newline-separated stream of "chunks". Each
chunk is an envelope that can container either a response message
or an error. Only the last chunk will include an error, and only
when the RPC handler ends abnormally (i.e. with an error code).

Because of the way the errors are included in the response body,
the other error handler signature is insufficient. So for server
streams, you must install a _different_ error handler:

```go
mux := runtime.NewServeMux(
runtime.WithStreamErrorHandler(handleStreamError))
```

The signature of the handler is much more rigid because we need
to know the structure of the error payload in order to properly
encode the "chunk" schema into a Swagger/OpenAPI spec.

So the function must return a `*runtime.StreamError`. The handler
can choose to omit some fields and can filter/transform the original
error, such as stripping stack traces from error messages.

Here's an example custom handler:
```go
// handleStreamError overrides default behavior for computing an error
// message for a server stream.
//
// It uses a default "502 Bad Gateway" HTTP code; only emits "safe"
// messages; and does not set gRPC code or details fields (so they will
// be omitted from the resulting JSON object that is sent to client).
func handleStreamError(ctx context.Context, err error) *runtime.StreamError {
code := http.StatusBadGateway
msg := "unexpected error"
if s, ok := status.FromError(err); ok {
code = runtime.HTTPStatusFromCode(s.Code())
// default message, based on the name of the gRPC code
msg = code.String()
// see if error details include "safe" message to send
// to external callers
for _, msg := s.Details() {
if safe, ok := msg.(*SafeMessage); ok {
msg = safe.Text
break
}
}
}
return &runtime.StreamError{
HttpCode: int32(code),
HttpStatus: http.StatusText(code),
Message: msg,
}
}
```

If no custom handler is provided, the default stream error handler
will include any gRPC error attributes (code, message, detail messages),
if the error being reported includes them. If the error does not have
these attributes, a gRPC code of `Unknown` (2) is reported. The default
handler will also include an HTTP code and status, which is derived
from the gRPC code (or set to `"500 Internal Server Error"` when
the source error has no gRPC attributes).

## Replace a response forwarder per method
You might want to keep the behavior of the current marshaler but change only a message forwarding of a certain API method.

Expand Down
72 changes: 33 additions & 39 deletions runtime/handler.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
package runtime

import (
"errors"
"fmt"
"io"
"net/http"
"net/textproto"

"context"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/any"
"github.com/grpc-ecosystem/grpc-gateway/internal"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)

var errEmptyResponse = errors.New("empty response")

// ForwardResponseStream forwards the stream from gRPC server to REST client.
func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
f, ok := w.(http.Flusher)
Expand Down Expand Up @@ -53,18 +53,18 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
return
}
if err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}

buf, err := marshaler.Marshal(streamChunk(resp, nil))
buf, err := marshaler.Marshal(streamChunk(ctx, resp, mux.streamErrorHandler))
if err != nil {
grpclog.Infof("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if _, err = w.Write(buf); err != nil {
Expand Down Expand Up @@ -124,7 +124,7 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha

contentType := marshaler.ContentType()
// Check marshaler on run time in order to keep backwards compatability
// An interface param needs to be added to the ContentType() function on
// An interface param needs to be added to the ContentType() function on
// the Marshal interface to be able to remove this check
if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok {
contentType = httpBodyMarshaler.ContentTypeFromMessage(resp)
Expand Down Expand Up @@ -168,48 +168,42 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re
return nil
}

func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
buf, merr := marshaler.Marshal(streamChunk(nil, err))
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
serr := streamError(ctx, mux.streamErrorHandler, err)
if !wroteHeader {
w.WriteHeader(int(serr.HttpCode))
}
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
buf, merr := marshaler.Marshal(errorChunk(serr))
if merr != nil {
grpclog.Infof("Failed to marshal an error: %v", merr)
return
}
if !wroteHeader {
s, ok := status.FromError(err)
if !ok {
s = status.New(codes.Unknown, err.Error())
}
w.WriteHeader(HTTPStatusFromCode(s.Code()))
}
if _, werr := w.Write(buf); werr != nil {
grpclog.Infof("Failed to notify error to client: %v", werr)
return
}
}

func streamChunk(result proto.Message, err error) map[string]proto.Message {
if err != nil {
grpcCode := codes.Unknown
grpcMessage := err.Error()
var grpcDetails []*any.Any
if s, ok := status.FromError(err); ok {
grpcCode = s.Code()
grpcMessage = s.Message()
grpcDetails = s.Proto().GetDetails()
}
httpCode := HTTPStatusFromCode(grpcCode)
return map[string]proto.Message{
"error": &internal.StreamError{
GrpcCode: int32(grpcCode),
HttpCode: int32(httpCode),
Message: grpcMessage,
HttpStatus: http.StatusText(httpCode),
Details: grpcDetails,
},
}
}
// streamChunk returns a chunk in a response stream for the given result. The
// given errHandler is used to render an error chunk if result is nil.
func streamChunk(ctx context.Context, result proto.Message, errHandler StreamErrorHandlerFunc) map[string]proto.Message {
if result == nil {
return streamChunk(nil, fmt.Errorf("empty response"))
return errorChunk(streamError(ctx, errHandler, errEmptyResponse))
}
return map[string]proto.Message{"result": result}
}

// streamError returns the payload for the final message in a response stream
// that represents the given err.
func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
serr := errHandler(ctx, err)
if serr != nil {
return serr
}
// TODO: log about misbehaving stream error handler?
return DefaultHTTPStreamErrorHandler(ctx, err)
}

func errorChunk(err *StreamError) map[string]proto.Message {
return map[string]proto.Message{"error": (*internal.StreamError)(err)}
}
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
16 changes: 16 additions & 0 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type ServeMux struct {
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
streamErrorHandler StreamErrorHandlerFunc
protoErrorHandler ProtoErrorHandlerFunc
disablePathLengthFallback bool
}
Expand Down Expand Up @@ -110,12 +111,27 @@ func WithDisablePathLengthFallback() ServeMuxOption {
}
}

// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// For stream errors that occur before any response has been written, the mux's
// ProtoErrorHandler will be invoked. However, once data has been written, the errors must
// be handled differently: they must be included in the response body. The response body's
// final message will include the error details returned by the stream error handler.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.streamErrorHandler = fn
}
}

// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
handlers: make(map[string][]handler),
forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
marshalers: makeMarshalerMIMERegistry(),
streamErrorHandler: DefaultHTTPStreamErrorHandler,
}

for _, opt := range opts {
Expand Down
40 changes: 38 additions & 2 deletions runtime/proto_errors.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
package runtime

import (
"context"
"io"
"net/http"

"context"
"github.com/golang/protobuf/ptypes/any"
"github.com/grpc-ecosystem/grpc-gateway/internal"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)

// StreamErrorHandlerFunc accepts an error as a gRPC error generated via status package and translates it into a
// a proto struct used to represent error at the end of a stream.
type StreamErrorHandlerFunc func(context.Context, error) *StreamError

// StreamError is the payload for the final message in a server stream in the event that the server returns an
// error after a response message has already been sent.
type StreamError internal.StreamError

// ProtoErrorHandlerFunc handles the error as a gRPC error generated via status package and replies to the request.
type ProtoErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)

Expand All @@ -35,7 +45,7 @@ func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler

contentType := marshaler.ContentType()
// Check marshaler on run time in order to keep backwards compatability
// An interface param needs to be added to the ContentType() function on
// An interface param needs to be added to the ContentType() function on
// the Marshal interface to be able to remove this check
if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok {
pb := s.Proto()
Expand Down Expand Up @@ -68,3 +78,29 @@ func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler

handleForwardResponseTrailer(w, md)
}

// DefaultHTTPStreamErrorHandler converts the given err into a *StreamError via
// default logic.
//
// It extracts the gRPC status from err if possible. The fields of the status are
// used to populate the returned StreamError, and the HTTP status code is derived
// from the gRPC code via HTTPStatusFromCode. If the given err does not contain a
// gRPC status, an "Unknown" gRPC code is used and "Internal Server Error" HTTP code.
func DefaultHTTPStreamErrorHandler(_ context.Context, err error) *StreamError {
grpcCode := codes.Unknown
grpcMessage := err.Error()
var grpcDetails []*any.Any
if s, ok := status.FromError(err); ok {
grpcCode = s.Code()
grpcMessage = s.Message()
grpcDetails = s.Proto().GetDetails()
}
httpCode := HTTPStatusFromCode(grpcCode)
return &StreamError{
GrpcCode: int32(grpcCode),
HttpCode: int32(httpCode),
Message: grpcMessage,
HttpStatus: http.StatusText(httpCode),
Details: grpcDetails,
}
}