Skip to content

Commit

Permalink
ability to customize stream errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump committed May 9, 2019
1 parent 15c7b2e commit c11240b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 12 deletions.
36 changes: 24 additions & 12 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,18 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
return
}
if err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}

buf, err := marshaler.Marshal(streamChunk(resp, nil))
if err != nil {
grpclog.Infof("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if _, err = w.Write(buf); err != nil {
Expand Down Expand Up @@ -168,18 +168,30 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re
return nil
}

func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
buf, merr := marshaler.Marshal(streamChunk(nil, err))
if merr != nil {
grpclog.Infof("Failed to marshal an error: %v", merr)
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
if !wroteHeader {
// if we haven't already written header, use normal error handler to render
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
if !wroteHeader {
s, ok := status.FromError(err)
if !ok {
s = status.New(codes.Unknown, err.Error())

// otherwise, render the error as a stream chunk in the response
var chunk map[string]proto.Message
if mux.streamErrorHandler != nil {
serr := mux.streamErrorHandler(ctx, err)
if serr == nil {
// TODO: log error about misbehaving error handler?
chunk = streamChunk(nil, err)
} else {
chunk = map[string]proto.Message{"error": (*internal.StreamError)(serr)}
}
w.WriteHeader(HTTPStatusFromCode(s.Code()))
} else {
chunk = streamChunk(nil, err)
}
buf, merr := marshaler.Marshal(chunk)
if merr != nil {
grpclog.Infof("Failed to marshal an error: %v", merr)
return
}
if _, werr := w.Write(buf); werr != nil {
grpclog.Infof("Failed to notify error to client: %v", werr)
Expand Down
15 changes: 15 additions & 0 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type ServeMux struct {
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
streamErrorHandler StreamErrorHandlerFunc
protoErrorHandler ProtoErrorHandlerFunc
disablePathLengthFallback bool
}
Expand Down Expand Up @@ -110,6 +111,20 @@ func WithDisablePathLengthFallback() ServeMuxOption {
}
}

// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// For stream errors that occur before any response has been written, the mux's
// ProtoErrorHandler will be invoked. However, once data has been written, the errors must
// be handled differently: they must included in the response body. The response body's
// final message will include the error details returned by the stream error handler.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.streamErrorHandler = fn
}
}

// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
Expand Down
9 changes: 9 additions & 0 deletions runtime/proto_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
"github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
)

// StreamErrorHandlerFunc accepts an error as a gRPC error generated via status package and translates it into a
// a proto struct used to represent error at the end of a stream.
type StreamErrorHandlerFunc func(context.Context, error) *StreamError

// StreamError is the payload for the final message in a server stream in the event that the server returns an
// error after a response message has already been sent.
type StreamError internal.StreamError

// ProtoErrorHandlerFunc handles the error as a gRPC error generated via status package and replies to the request.
type ProtoErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)

Expand Down

0 comments on commit c11240b

Please sign in to comment.