From f33cdd4864b4c6f67b18dc8f09e2a8cd878ae7fc Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 7 Nov 2017 15:43:22 +0000 Subject: [PATCH] runtime: stream errors forwarding --- runtime/handler.go | 23 +++++--- runtime/handler_test.go | 116 ++++++++++++++++++++++++++-------------- 2 files changed, 92 insertions(+), 47 deletions(-) diff --git a/runtime/handler.go b/runtime/handler.go index ae6a5d551cf..38bd516a024 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -34,34 +34,36 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Content-Type", marshaler.ContentType()) if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + HTTPError(ctx, mux, marshaler, w, req, err) return } - w.WriteHeader(http.StatusOK) - f.Flush() + + var wroteHeader bool for { resp, err := recv() if err == io.EOF { return } if err != nil { - handleForwardResponseStreamError(marshaler, w, err) + handleForwardResponseStreamError(wroteHeader, marshaler, w, err) return } if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { - handleForwardResponseStreamError(marshaler, w, err) + handleForwardResponseStreamError(wroteHeader, marshaler, w, err) return } buf, err := marshaler.Marshal(streamChunk(resp, nil)) if err != nil { grpclog.Printf("Failed to marshal response chunk: %v", err) + handleForwardResponseStreamError(wroteHeader, marshaler, w, err) return } if _, err = w.Write(buf); err != nil { grpclog.Printf("Failed to send response chunk: %v", err) return } + wroteHeader = true f.Flush() } } @@ -134,13 +136,20 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re return nil } -func handleForwardResponseStreamError(marshaler Marshaler, w http.ResponseWriter, err error) { +func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) { buf, merr := marshaler.Marshal(streamChunk(nil, err)) if merr != nil { grpclog.Printf("Failed to marshal an error: %v", merr) return } - if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil { + if !wroteHeader { + s, ok := status.FromError(err) + if !ok { + s = status.New(codes.Unknown, err.Error()) + } + w.WriteHeader(HTTPStatusFromCode(s.Code())) + } + if _, werr := w.Write(buf); werr != nil { grpclog.Printf("Failed to notify error to client: %v", werr) return } diff --git a/runtime/handler_test.go b/runtime/handler_test.go index 8c61ee1510a..d63948cdd9f 100644 --- a/runtime/handler_test.go +++ b/runtime/handler_test.go @@ -11,57 +11,93 @@ import ( pb "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb" "github.com/grpc-ecosystem/grpc-gateway/runtime" "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" ) func TestForwardResponseStream(t *testing.T) { - var ( - msgs = []proto.Message{ - &pb.SimpleMessage{Id: "One"}, - &pb.SimpleMessage{Id: "Two"}, - } + type msg struct { + pb proto.Message + err error + } + tests := []struct { + name string + msgs []msg + statusCode int + }{{ + name: "encoding", + msgs: []msg{ + {&pb.SimpleMessage{Id: "One"}, nil}, + {&pb.SimpleMessage{Id: "Two"}, nil}, + }, + statusCode: http.StatusOK, + }, { + name: "empty", + statusCode: http.StatusOK, + }, { + name: "error", + msgs: []msg{{nil, grpc.Errorf(codes.OutOfRange, "400")}}, + statusCode: http.StatusBadRequest, + }, { + name: "stream_error", + msgs: []msg{ + {&pb.SimpleMessage{Id: "One"}, nil}, + {nil, grpc.Errorf(codes.OutOfRange, "400")}, + }, + statusCode: http.StatusOK, + }} - ctx = runtime.NewServerMetadataContext( - context.Background(), runtime.ServerMetadata{}, - ) - mux = runtime.NewServeMux() - marshaler = &runtime.JSONPb{} - req = httptest.NewRequest("GET", "http://example.com/foo", nil) - resp = httptest.NewRecorder() - count = 0 - recv = func() (proto.Message, error) { - if count >= len(msgs) { + newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) { + var count int + return func() (proto.Message, error) { + if count == len(msgs) { return nil, io.EOF + } else if count > len(msgs) { + t.Errorf("recv() called %d times for %d messages", count, len(msgs)) } count++ - return msgs[count-1], nil + msg := msgs[count-1] + return msg.pb, msg.err } - ) + } + ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{}) + marshaler := &runtime.JSONPb{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recv := newTestRecv(t, tt.msgs) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + resp := httptest.NewRecorder() - runtime.ForwardResponseStream(ctx, mux, marshaler, resp, req, recv) + runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv) - w := resp.Result() - if w.StatusCode != http.StatusOK { - t.Errorf(" got %d want %d", w.StatusCode, http.StatusOK) - } - if h := w.Header.Get("Transfer-Encoding"); h != "chunked" { - t.Errorf("ForwardResponseStream missing header chunked") - } - body, err := ioutil.ReadAll(w.Body) - if err != nil { - t.Errorf("Failed to read response body with %v", err) - } - w.Body.Close() + w := resp.Result() + if w.StatusCode != tt.statusCode { + t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode) + } + if h := w.Header.Get("Transfer-Encoding"); h != "chunked" { + t.Errorf("ForwardResponseStream missing header chunked") + } + body, err := ioutil.ReadAll(w.Body) + if err != nil { + t.Errorf("Failed to read response body with %v", err) + } + w.Body.Close() - var want []byte - for _, msg := range msgs { - b, err := marshaler.Marshal(map[string]proto.Message{"result": msg}) - if err != nil { - t.Errorf("marshaler.Marshal() failed %v", err) - } - want = append(want, b...) - } + var want []byte + for _, msg := range tt.msgs { + if msg.err != nil { + t.Skip("checking erorr encodings") + } + b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb}) + if err != nil { + t.Errorf("marshaler.Marshal() failed %v", err) + } + want = append(want, b...) + } - if string(body) != string(want) { - t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want) + if string(body) != string(want) { + t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want) + } + }) } }