diff --git a/runtime/context.go b/runtime/context.go index 03a3015a82a..896057e1e1e 100644 --- a/runtime/context.go +++ b/runtime/context.go @@ -1,14 +1,16 @@ package runtime import ( + "context" + "encoding/base64" "fmt" "net" "net/http" + "net/textproto" "strconv" "strings" "time" - "context" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" @@ -28,6 +30,7 @@ const MetadataPrefix = "grpcgateway-" const MetadataTrailerPrefix = "Grpc-Trailer-" const metadataGrpcTimeout = "Grpc-Timeout" +const metadataHeaderBinarySuffix = "-Bin" const xForwardedFor = "X-Forwarded-For" const xForwardedHost = "X-Forwarded-Host" @@ -38,6 +41,14 @@ var ( DefaultContextTimeout = 0 * time.Second ) +func decodeBinHeader(v string) ([]byte, error) { + if len(v)%4 == 0 { + // Input was padded, or padding was not necessary. + return base64.StdEncoding.DecodeString(v) + } + return base64.RawStdEncoding.DecodeString(v) +} + /* AnnotateContext adds context information such as metadata from the request. @@ -58,11 +69,22 @@ func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request) (con for key, vals := range req.Header { for _, val := range vals { + key = textproto.CanonicalMIMEHeaderKey(key) // For backwards-compatibility, pass through 'authorization' header with no prefix. - if strings.ToLower(key) == "authorization" { + if key == "Authorization" { pairs = append(pairs, "authorization", val) } if h, ok := mux.incomingHeaderMatcher(key); ok { + // Handles "-bin" metadata in grpc, since grpc will do another base64 + // encode before sending to server, we need to decode it first. + if strings.HasSuffix(key, metadataHeaderBinarySuffix) { + b, err := decodeBinHeader(val) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err) + } + + val = string(b) + } pairs = append(pairs, h, val) } } diff --git a/runtime/context_test.go b/runtime/context_test.go index e78a037d033..5f752408a03 100644 --- a/runtime/context_test.go +++ b/runtime/context_test.go @@ -1,12 +1,13 @@ package runtime_test import ( + "context" + "encoding/base64" "net/http" "reflect" "testing" "time" - "context" "github.com/grpc-ecosystem/grpc-gateway/runtime" "google.golang.org/grpc/metadata" ) @@ -68,6 +69,30 @@ func TestAnnotateContext_ForwardsGrpcMetadata(t *testing.T) { } } +func TestAnnotateContext_ForwardGrpcBinaryMetadata(t *testing.T) { + ctx := context.Background() + request, err := http.NewRequest("GET", "http://www.example.com", nil) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", "GET", "http://www.example.com", err) + } + + binData := []byte("\x00test-binary-data") + request.Header.Add("Grpc-Metadata-Test-Bin", base64.StdEncoding.EncodeToString(binData)) + + annotated, err := runtime.AnnotateContext(ctx, runtime.NewServeMux(), request) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + md, ok := metadata.FromOutgoingContext(annotated) + if !ok || len(md) != emptyForwardMetaCount+1 { + t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount+1, md) + } + if got, want := md["test-bin"], []string{string(binData)}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["test-bin"] = %q want %q`, got, want) + } +} + func TestAnnotateContext_XForwardedFor(t *testing.T) { ctx := context.Background() request, err := http.NewRequest("GET", "http://bar.foo.example.com", nil) diff --git a/runtime/mux.go b/runtime/mux.go index a184291e16a..8f7dd36b700 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -3,7 +3,6 @@ package runtime import ( "fmt" "net/http" - "net/textproto" "strings" "context" @@ -51,7 +50,6 @@ type HeaderMatcherFunc func(string) (string, bool) // keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with // 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'. func DefaultHeaderMatcher(key string) (string, bool) { - key = textproto.CanonicalMIMEHeaderKey(key) if isPermanentHTTPHeader(key) { return MetadataPrefix + key, true } else if strings.HasPrefix(key, MetadataHeaderPrefix) {