Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method for custom formatting of stream errors. #564

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
if d, ok := marshaler.(Delimited); ok {
delimiter = d.Delimiter()
} else {
delimiter = []byte("\n")
delimiter = []byte("\n")
}

var wroteHeader bool
Expand All @@ -52,18 +52,18 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
return
}
if err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(wroteHeader, marshaler, mux, w, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(wroteHeader, marshaler, mux, w, err)
return
}

buf, err := marshaler.Marshal(streamChunk(resp, nil))
buf, err := marshaler.Marshal(streamChunk(resp, nil, mux.protoStreamErrorFormatter))
if err != nil {
grpclog.Printf("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
handleForwardResponseStreamError(wroteHeader, marshaler, mux, w, err)
return
}
if _, err = w.Write(buf); err != nil {
Expand Down Expand Up @@ -147,8 +147,8 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re
return nil
}

func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
buf, merr := marshaler.Marshal(streamChunk(nil, err))
func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, mux *ServeMux, w http.ResponseWriter, err error) {
buf, merr := marshaler.Marshal(streamChunk(nil, err, mux.protoStreamErrorFormatter))
if merr != nil {
grpclog.Printf("Failed to marshal an error: %v", merr)
return
Expand All @@ -166,8 +166,11 @@ func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w h
}
}

func streamChunk(result proto.Message, err error) map[string]proto.Message {
func streamChunk(result proto.Message, err error, errorFormatter ProtoStreamErrorFormatterFunc) map[string]proto.Message {
if err != nil {
if errorFormatter != nil {
return errorFormatter(err)
}
grpcCode := codes.Unknown
if s, ok := status.FromError(err); ok {
grpcCode = s.Code()
Expand All @@ -183,7 +186,7 @@ func streamChunk(result proto.Message, err error) map[string]proto.Message {
}
}
if result == nil {
return streamChunk(nil, fmt.Errorf("empty response"))
return streamChunk(nil, fmt.Errorf("empty response"), errorFormatter)
}
return map[string]proto.Message{"result": result}
}
27 changes: 19 additions & 8 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[str
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotator func(context.Context, *http.Request) metadata.MD
protoErrorHandler ProtoErrorHandlerFunc
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
outgoingHeaderMatcher HeaderMatcherFunc
metadataAnnotator func(context.Context, *http.Request) metadata.MD
protoErrorHandler ProtoErrorHandlerFunc
protoStreamErrorFormatter ProtoStreamErrorFormatterFunc
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand Down Expand Up @@ -91,7 +92,7 @@ func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) Se
}
}

// WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
// WithProtoErrorHandler returns a ServeMuxOption for controlling the error output of the gateway.
//
// This can be used to handle an error as general proto message defined by gRPC.
// The response including body and status is not backward compatible with the default error handler.
Expand All @@ -102,6 +103,16 @@ func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
}
}

// WithProtoStreamErrorFormatter returns a ServeMuxOption for controlling
// the error output of streaming endpoints of the gateway.
//
// This can be used to control the message returned to streaming endpoints when an error occurs.
func WithProtoStreamErrorFormatter(fn ProtoErrorHandlerFunc) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.protoErrorHandler = fn
}
}

// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
serveMux := &ServeMux{
Expand Down
20 changes: 20 additions & 0 deletions runtime/proto_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"net/http"

"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
Expand Down Expand Up @@ -59,3 +60,22 @@ func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler

handleForwardResponseTrailer(w, md)
}

// ProtoStreamErrorFormatterFunc handles the error as a gRPC error generated via status package and sends it on the stream.
type ProtoStreamErrorFormatterFunc func(error) map[string]proto.Message

var _ ProtoStreamErrorFormatterFunc = DefaultStreamProtoErrorFormatter

// DefaultStreamProtoErrorFormatter is an implementation of ProtoStreamErrorFormatterFunc.
// It simply returns the protobuf representation of the gRPC status.
// This is consistent with DefaultHTTPProtoErrorHandler.
func DefaultStreamProtoErrorFormatter(err error) map[string]proto.Message {
s, ok := status.FromError(err)
if !ok {
s = status.New(codes.Unknown, err.Error())
}

return map[string]proto.Message{
"error": s.Proto(),
}
}