Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ability to customize stream errors #930

Merged
merged 5 commits into from
May 9, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,18 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
return
}
if err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}

buf, err := marshaler.Marshal(streamChunk(resp, nil))
if err != nil {
grpclog.Infof("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
return
}
if _, err = w.Write(buf); err != nil {
Expand Down Expand Up @@ -168,18 +168,30 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re
return nil
}

func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
buf, merr := marshaler.Marshal(streamChunk(nil, err))
if merr != nil {
grpclog.Infof("Failed to marshal an error: %v", merr)
func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
if !wroteHeader {
// if we haven't already written header, use normal error handler to render
HTTPError(ctx, mux, marshaler, w, req, err)
return
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
}
if !wroteHeader {
s, ok := status.FromError(err)
if !ok {
s = status.New(codes.Unknown, err.Error())

// otherwise, render the error as a stream chunk in the response
var chunk map[string]proto.Message
if mux.streamErrorHandler != nil {
serr := mux.streamErrorHandler(ctx, err)
if serr == nil {
// TODO: log error about misbehaving error handler?
chunk = streamChunk(nil, err)
} else {
chunk = map[string]proto.Message{"error": (*internal.StreamError)(serr)}
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
}
w.WriteHeader(HTTPStatusFromCode(s.Code()))
} else {
chunk = streamChunk(nil, err)
}
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
buf, merr := marshaler.Marshal(chunk)
if merr != nil {
grpclog.Infof("Failed to marshal an error: %v", merr)
return
}
if _, werr := w.Write(buf); werr != nil {
grpclog.Infof("Failed to notify error to client: %v", werr)
Expand Down
15 changes: 15 additions & 0 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type ServeMux struct {
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotators []func(context.Context, *http.Request) metadata.MD
streamErrorHandler StreamErrorHandlerFunc
protoErrorHandler ProtoErrorHandlerFunc
disablePathLengthFallback bool
}
Expand Down Expand Up @@ -110,6 +111,20 @@ func WithDisablePathLengthFallback() ServeMuxOption {
}
}

// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// For stream errors that occur before any response has been written, the mux's
// ProtoErrorHandler will be invoked. However, once data has been written, the errors must
// be handled differently: they must be included in the response body. The response body's
// final message will include the error details returned by the stream error handler.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.streamErrorHandler = fn
}
}

// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
Expand Down
11 changes: 10 additions & 1 deletion runtime/proto_errors.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
package runtime

import (
"context"
"io"
"net/http"

"context"
"github.com/grpc-ecosystem/grpc-gateway/internal"
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)

// StreamErrorHandlerFunc accepts an error as a gRPC error generated via status package and translates it into a
// a proto struct used to represent error at the end of a stream.
type StreamErrorHandlerFunc func(context.Context, error) *StreamError

// StreamError is the payload for the final message in a server stream in the event that the server returns an
// error after a response message has already been sent.
type StreamError internal.StreamError

// ProtoErrorHandlerFunc handles the error as a gRPC error generated via status package and replies to the request.
type ProtoErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)

Expand Down