Skip to content

Commit

Permalink
Restrict handler metadata headers (#748)
Browse files Browse the repository at this point in the history
This PR fixes the setting of protocol headers to avoid multiple value
headers when providing metadata to a handler. The metadata headers are
further restricted to avoid setting protocol headers like
"Content-Type". This restriction allows the user to pass the response of
a proxy call to a handler without having to filter the response headers
themselves. This enforces the protocol headers are set by the handler
and that they are unaffected from any user provided metadata.

Previously, returning the response of a client request to a handler
would merge the headers together leading to protocol errors from invalid
headers such as "Content-Type" having multiple values.

---------

Signed-off-by: Edward McFarlane <emcfarlane@buf.build>
  • Loading branch information
emcfarlane authored May 31, 2024
1 parent 8349c6d commit aba3ff5
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 34 deletions.
3 changes: 3 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,6 @@ issues:
# We want to show examples with http.Get
- linters: [noctx]
path: internal/memhttp/memhttp_test.go
# We need to initialize a map of all protocol headers
- linters: [gochecknoglobals]
path: header.go
64 changes: 64 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2631,6 +2631,70 @@ func TestBlankImportCodeGeneration(t *testing.T) {
assert.NotNil(t, desc)
}

// TestSetProtocolHeaders tests that headers required by the protocols are set
// overriding user provided headers.
func TestSetProtocolHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
clientOption connect.ClientOption
expectContentType string
}{{
name: "connect",
expectContentType: "application/proto",
}, {
name: "grpc",
clientOption: connect.WithGRPC(),
expectContentType: "application/grpc",
}, {
name: "grpcweb",
clientOption: connect.WithGRPCWeb(),
expectContentType: "application/grpc-web+proto",
}}
for _, tt := range tests {
testcase := tt
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
pingServer := &pingServer{}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)

clientOpts := []connect.ClientOption{}
if testcase.clientOption == nil {
// Use a different protocol to test the override.
clientOpts = append(clientOpts, connect.WithGRPC())
}
client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), clientOpts...)

pingProxyServer := &pluggablePingServer{
ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) {
return client.Ping(ctx, request)
},
}
proxyMux := http.NewServeMux()
proxyMux.Handle(pingv1connect.NewPingServiceHandler(pingProxyServer))
proxyServer := memhttptest.NewServer(t, proxyMux)

proxyClientOpts := []connect.ClientOption{}
if testcase.clientOption != nil {
proxyClientOpts = append(proxyClientOpts, testcase.clientOption)
}
proxyClient := pingv1connect.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...)

request := connect.NewRequest(&pingv1.PingRequest{Number: 42})
request.Header().Set("X-Test", t.Name())
response, err := proxyClient.Ping(context.Background(), request)
if !assert.Nil(t, err) {
return
}
// Assert the Content-Type is set for the proxy clients protocol and not the client's.
assert.Equal(t, response.Header().Get("Content-Type"), testcase.expectContentType)
assert.Equal(t, len(response.Header().Values("Content-Type")), 1)
})
}
}

type unflushableWriter struct {
w http.ResponseWriter
}
Expand Down
2 changes: 1 addition & 1 deletion error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request,

func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error {
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(response.Header(), connectErr.meta)
mergeNonProtocolHeaders(response.Header(), connectErr.meta)
}
response.WriteHeader(connectCodeToHTTP(CodeOf(err)))
data, marshalErr := json.Marshal(newConnectWireError(err))
Expand Down
4 changes: 2 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ func NewUnaryHandler[Req, Res any](
if err != nil {
return err
}
mergeHeaders(conn.ResponseHeader(), response.Header())
mergeHeaders(conn.ResponseTrailer(), response.Trailer())
mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header())
mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer())
return conn.Send(response.Any())
}

Expand Down
59 changes: 31 additions & 28 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,33 @@ import (
"net/http"
)

var (
protocolHeaders = map[string]struct{}{
// HTTP headers.
headerContentType: {},
headerContentLength: {},
headerContentEncoding: {},
headerHost: {},
headerUserAgent: {},
headerTrailer: {},
headerDate: {},
// Connect headers.
connectUnaryHeaderAcceptCompression: {},
connectUnaryTrailerPrefix: {},
connectStreamingHeaderCompression: {},
connectStreamingHeaderAcceptCompression: {},
connectHeaderTimeout: {},
connectHeaderProtocolVersion: {},
// gRPC headers.
grpcHeaderCompression: {},
grpcHeaderAcceptCompression: {},
grpcHeaderTimeout: {},
grpcHeaderStatus: {},
grpcHeaderMessage: {},
grpcHeaderDetails: {},
}
)

// EncodeBinaryHeader base64-encodes the data. It always emits unpadded values.
//
// In the Connect, gRPC, and gRPC-Web protocols, binary headers must have keys
Expand Down Expand Up @@ -57,41 +84,17 @@ func mergeHeaders(into, from http.Header) {
}
}

// mergeMetdataHeaders merges the metadata headers from the "from" header into
// the "into" header. It skips over non metadata headers that should not be
// propagated from the server to the client.
func mergeMetadataHeaders(into, from http.Header) {
// mergeNonProtocolHeaders merges headers excluding protocol headers defined in
// protocolHeaders.
func mergeNonProtocolHeaders(into, from http.Header) {
for key, vals := range from {
if len(vals) == 0 {
// For response trailers, net/http will pre-populate entries
// with nil values based on the "Trailer" header. But if there
// are no actual values for those keys, we skip them.
continue
}
switch http.CanonicalHeaderKey(key) {
case headerContentType,
headerContentLength,
headerContentEncoding,
headerHost,
headerUserAgent,
headerTrailer,
headerDate:
// HTTP headers.
case connectUnaryHeaderAcceptCompression,
connectUnaryTrailerPrefix,
connectStreamingHeaderCompression,
connectStreamingHeaderAcceptCompression,
connectHeaderTimeout,
connectHeaderProtocolVersion:
// Connect headers.
case grpcHeaderCompression,
grpcHeaderAcceptCompression,
grpcHeaderTimeout,
grpcHeaderStatus,
grpcHeaderMessage,
grpcHeaderDetails:
// gRPC headers.
default:
if _, isProtocolHeader := protocolHeaders[key]; !isProtocolHeader {
into[key] = append(into[key], vals...)
}
}
Expand Down
4 changes: 2 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ func (hc *connectUnaryHandlerConn) mergeResponseHeader(err error) {
}
if err != nil {
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(header, connectErr.meta)
mergeNonProtocolHeaders(header, connectErr.meta)
}
}
for k, v := range hc.responseTrailer {
Expand Down Expand Up @@ -850,7 +850,7 @@ func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Hea
if err != nil {
end.Error = newConnectWireError(err)
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(end.Trailer, connectErr.meta)
mergeNonProtocolHeaders(end.Trailer, connectErr.meta)
}
}
data, marshalErr := json.Marshal(end)
Expand Down
2 changes: 1 addition & 1 deletion protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ func grpcErrorToTrailer(trailer http.Header, protobuf Codec, err error) {
return
}
if connectErr, ok := asError(err); ok && !connectErr.wireErr {
mergeMetadataHeaders(trailer, connectErr.meta)
mergeNonProtocolHeaders(trailer, connectErr.meta)
}
var (
status = grpcStatusFromError(err)
Expand Down

0 comments on commit aba3ff5

Please sign in to comment.