diff --git a/runtime/handler.go b/runtime/handler.go index 1770b85344d..1626fa79d96 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -42,7 +42,7 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal if d, ok := marshaler.(Delimited); ok { delimiter = d.Delimiter() } else { - delimiter = []byte("\n") + delimiter = []byte("\n") } var wroteHeader bool @@ -52,18 +52,18 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal return } if err != nil { - handleForwardResponseStreamError(wroteHeader, marshaler, w, err) + handleForwardResponseStreamError(wroteHeader, marshaler, mux, w, err) return } if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { - handleForwardResponseStreamError(wroteHeader, marshaler, w, err) + handleForwardResponseStreamError(wroteHeader, marshaler, mux, w, err) return } - buf, err := marshaler.Marshal(streamChunk(resp, nil)) + buf, err := marshaler.Marshal(streamChunk(resp, nil, mux.protoStreamErrorFormatter)) if err != nil { grpclog.Printf("Failed to marshal response chunk: %v", err) - handleForwardResponseStreamError(wroteHeader, marshaler, w, err) + handleForwardResponseStreamError(wroteHeader, marshaler, mux, w, err) return } if _, err = w.Write(buf); err != nil { @@ -147,8 +147,8 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re return nil } -func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) { - buf, merr := marshaler.Marshal(streamChunk(nil, err)) +func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, mux *ServeMux, w http.ResponseWriter, err error) { + buf, merr := marshaler.Marshal(streamChunk(nil, err, mux.protoStreamErrorFormatter)) if merr != nil { grpclog.Printf("Failed to marshal an error: %v", merr) return @@ -166,8 +166,11 @@ func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w h } } -func streamChunk(result proto.Message, err error) map[string]proto.Message { +func streamChunk(result proto.Message, err error, errorFormatter ProtoStreamErrorFormatterFunc) map[string]proto.Message { if err != nil { + if errorFormatter != nil { + return errorFormatter(err) + } grpcCode := codes.Unknown if s, ok := status.FromError(err); ok { grpcCode = s.Code() @@ -183,7 +186,7 @@ func streamChunk(result proto.Message, err error) map[string]proto.Message { } } if result == nil { - return streamChunk(nil, fmt.Errorf("empty response")) + return streamChunk(nil, fmt.Errorf("empty response"), errorFormatter) } return map[string]proto.Message{"result": result} } diff --git a/runtime/mux.go b/runtime/mux.go index 205bc430921..e1b00afa9d1 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -20,13 +20,14 @@ type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[str // It matches http requests to patterns and invokes the corresponding handler. type ServeMux struct { // handlers maps HTTP method to a list of handlers. - handlers map[string][]handler - forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error - marshalers marshalerRegistry - incomingHeaderMatcher HeaderMatcherFunc - outgoingHeaderMatcher HeaderMatcherFunc - metadataAnnotator func(context.Context, *http.Request) metadata.MD - protoErrorHandler ProtoErrorHandlerFunc + handlers map[string][]handler + forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error + marshalers marshalerRegistry + incomingHeaderMatcher HeaderMatcherFunc + outgoingHeaderMatcher HeaderMatcherFunc + metadataAnnotator func(context.Context, *http.Request) metadata.MD + protoErrorHandler ProtoErrorHandlerFunc + protoStreamErrorFormatter ProtoStreamErrorFormatterFunc } // ServeMuxOption is an option that can be given to a ServeMux on construction. @@ -91,7 +92,7 @@ func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) Se } } -// WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context. +// WithProtoErrorHandler returns a ServeMuxOption for controlling the error output of the gateway. // // This can be used to handle an error as general proto message defined by gRPC. // The response including body and status is not backward compatible with the default error handler. @@ -102,6 +103,16 @@ func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption { } } +// WithProtoStreamErrorFormatter returns a ServeMuxOption for controlling +// the error output of streaming endpoints of the gateway. +// +// This can be used to control the message returned to streaming endpoints when an error occurs. +func WithProtoStreamErrorFormatter(fn ProtoErrorHandlerFunc) ServeMuxOption { + return func(serveMux *ServeMux) { + serveMux.protoErrorHandler = fn + } +} + // NewServeMux returns a new ServeMux whose internal mapping is empty. func NewServeMux(opts ...ServeMuxOption) *ServeMux { serveMux := &ServeMux{ diff --git a/runtime/proto_errors.go b/runtime/proto_errors.go index b1b089273b6..ec083463743 100644 --- a/runtime/proto_errors.go +++ b/runtime/proto_errors.go @@ -4,6 +4,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" @@ -59,3 +60,22 @@ func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler handleForwardResponseTrailer(w, md) } + +// ProtoStreamErrorFormatterFunc handles the error as a gRPC error generated via status package and sends it on the stream. +type ProtoStreamErrorFormatterFunc func(error) map[string]proto.Message + +var _ ProtoStreamErrorFormatterFunc = DefaultStreamProtoErrorFormatter + +// DefaultStreamProtoErrorFormatter is an implementation of ProtoStreamErrorFormatterFunc. +// It simply returns the protobuf representation of the gRPC status. +// This is consistent with DefaultHTTPProtoErrorHandler. +func DefaultStreamProtoErrorFormatter(err error) map[string]proto.Message { + s, ok := status.FromError(err) + if !ok { + s = status.New(codes.Unknown, err.Error()) + } + + return map[string]proto.Message{ + "error": s.Proto(), + } +}