Skip to content

Commit

Permalink
add unit tests and fix casing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Lee authored and Ian Lee committed Oct 24, 2024
1 parent dbf58d4 commit fd6924c
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 4 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/internal/proto/examplepb/stream.pb.gw.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion protoc-gen-grpc-gateway/internal/gengateway/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
go func() {
for err := range reqErrChan {
if err != nil && err != io.EOF {
runtime.HttpStreamError(annotatedContext, mux, outboundMarshaler, w, req, err)
runtime.HTTPStreamError(annotatedContext, mux, outboundMarshaler, w, req, err)
}
}
}()
Expand Down
2 changes: 1 addition & 1 deletion runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R
}

// HttpStreamError uses the mux-configured stream error handler to notify error to the client without closing the connection.
func HttpStreamError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
func HTTPStreamError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
st := mux.streamErrorHandler(ctx, err)
msg := errorChunk(st)
buf, err := marshaler.Marshal(msg)
Expand Down
50 changes: 50 additions & 0 deletions runtime/errors_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package runtime_test

import (
"bytes"
"context"
"errors"
"net/http"
Expand Down Expand Up @@ -132,3 +133,52 @@ func TestDefaultHTTPError(t *testing.T) {
})
}
}

func TestHTTPStreamError(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
name string
err error
expectedStatus *status.Status
expectedResponse []byte
}{
{
name: "Simple error",
err: errors.New("simple error"),
expectedStatus: status.New(codes.Unknown, "simple error"),
expectedResponse: []byte(`{"error":{"code":2,"message":"simple error"}}`),
},
{
name: "Invalid request error",
err: status.Error(codes.InvalidArgument, "invalid request"),
expectedStatus: status.New(codes.InvalidArgument, "invalid request"),
expectedResponse: []byte(`{"error":{"code":3,"message":"invalid request"}}`),
},
} {
t.Run(tc.name, func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

mux := runtime.NewServeMux(runtime.WithStreamErrorHandler(
runtime.DefaultStreamErrorHandler,
))

marshaler := &runtime.JSONPb{}

runtime.HTTPStreamError(ctx, mux, marshaler, w, r, tc.err)

if w.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
}

if !proto.Equal(status.Convert(tc.err).Proto(), tc.expectedStatus.Proto()) {
t.Errorf("Expected status %v, got %v", tc.expectedStatus, status.Convert(tc.err))
}

if !bytes.Equal(w.Body.Bytes(), tc.expectedResponse) {
t.Errorf("Expected response %s, got %s", tc.expectedResponse, w.Body.Bytes())
}
})
}
}

0 comments on commit fd6924c

Please sign in to comment.