Skip to content

Commit

Permalink
runtime: stream errors forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane authored and achew22 committed Nov 18, 2017
1 parent 0395325 commit f33cdd4
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 47 deletions.
23 changes: 16 additions & 7 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,36 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Content-Type", marshaler.ContentType())
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
w.WriteHeader(http.StatusOK)
f.Flush()

var wroteHeader bool
for {
resp, err := recv()
if err == io.EOF {
return
}
if err != nil {
handleForwardResponseStreamError(marshaler, w, err)
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
return
}
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
handleForwardResponseStreamError(marshaler, w, err)
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
return
}

buf, err := marshaler.Marshal(streamChunk(resp, nil))
if err != nil {
grpclog.Printf("Failed to marshal response chunk: %v", err)
handleForwardResponseStreamError(wroteHeader, marshaler, w, err)
return
}
if _, err = w.Write(buf); err != nil {
grpclog.Printf("Failed to send response chunk: %v", err)
return
}
wroteHeader = true
f.Flush()
}
}
Expand Down Expand Up @@ -134,13 +136,20 @@ func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, re
return nil
}

func handleForwardResponseStreamError(marshaler Marshaler, w http.ResponseWriter, err error) {
func handleForwardResponseStreamError(wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, err error) {
buf, merr := marshaler.Marshal(streamChunk(nil, err))
if merr != nil {
grpclog.Printf("Failed to marshal an error: %v", merr)
return
}
if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil {
if !wroteHeader {
s, ok := status.FromError(err)
if !ok {
s = status.New(codes.Unknown, err.Error())
}
w.WriteHeader(HTTPStatusFromCode(s.Code()))
}
if _, werr := w.Write(buf); werr != nil {
grpclog.Printf("Failed to notify error to client: %v", werr)
return
}
Expand Down
116 changes: 76 additions & 40 deletions runtime/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,57 +11,93 @@ import (
pb "github.com/grpc-ecosystem/grpc-gateway/examples/examplepb"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)

func TestForwardResponseStream(t *testing.T) {
var (
msgs = []proto.Message{
&pb.SimpleMessage{Id: "One"},
&pb.SimpleMessage{Id: "Two"},
}
type msg struct {
pb proto.Message
err error
}
tests := []struct {
name string
msgs []msg
statusCode int
}{{
name: "encoding",
msgs: []msg{
{&pb.SimpleMessage{Id: "One"}, nil},
{&pb.SimpleMessage{Id: "Two"}, nil},
},
statusCode: http.StatusOK,
}, {
name: "empty",
statusCode: http.StatusOK,
}, {
name: "error",
msgs: []msg{{nil, grpc.Errorf(codes.OutOfRange, "400")}},
statusCode: http.StatusBadRequest,
}, {
name: "stream_error",
msgs: []msg{
{&pb.SimpleMessage{Id: "One"}, nil},
{nil, grpc.Errorf(codes.OutOfRange, "400")},
},
statusCode: http.StatusOK,
}}

ctx = runtime.NewServerMetadataContext(
context.Background(), runtime.ServerMetadata{},
)
mux = runtime.NewServeMux()
marshaler = &runtime.JSONPb{}
req = httptest.NewRequest("GET", "http://example.com/foo", nil)
resp = httptest.NewRecorder()
count = 0
recv = func() (proto.Message, error) {
if count >= len(msgs) {
newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
var count int
return func() (proto.Message, error) {
if count == len(msgs) {
return nil, io.EOF
} else if count > len(msgs) {
t.Errorf("recv() called %d times for %d messages", count, len(msgs))
}
count++
return msgs[count-1], nil
msg := msgs[count-1]
return msg.pb, msg.err
}
)
}
ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
marshaler := &runtime.JSONPb{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recv := newTestRecv(t, tt.msgs)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
resp := httptest.NewRecorder()

runtime.ForwardResponseStream(ctx, mux, marshaler, resp, req, recv)
runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)

w := resp.Result()
if w.StatusCode != http.StatusOK {
t.Errorf(" got %d want %d", w.StatusCode, http.StatusOK)
}
if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
t.Errorf("ForwardResponseStream missing header chunked")
}
body, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Errorf("Failed to read response body with %v", err)
}
w.Body.Close()
w := resp.Result()
if w.StatusCode != tt.statusCode {
t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
}
if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
t.Errorf("ForwardResponseStream missing header chunked")
}
body, err := ioutil.ReadAll(w.Body)
if err != nil {
t.Errorf("Failed to read response body with %v", err)
}
w.Body.Close()

var want []byte
for _, msg := range msgs {
b, err := marshaler.Marshal(map[string]proto.Message{"result": msg})
if err != nil {
t.Errorf("marshaler.Marshal() failed %v", err)
}
want = append(want, b...)
}
var want []byte
for _, msg := range tt.msgs {
if msg.err != nil {
t.Skip("checking erorr encodings")
}
b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
if err != nil {
t.Errorf("marshaler.Marshal() failed %v", err)
}
want = append(want, b...)
}

if string(body) != string(want) {
t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
if string(body) != string(want) {
t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
}
})
}
}

0 comments on commit f33cdd4

Please sign in to comment.