diff --git a/connect_ext_test.go b/connect_ext_test.go index fefb9f7d..86bdeeff 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2423,7 +2423,7 @@ func TestClientDisconnect(t *testing.T) { assert.NotNil(t, err) <-gotResponse assert.NotNil(t, handlerReceiveErr) - assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled) + assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled, assert.Sprintf("got %v", handlerReceiveErr)) assert.ErrorIs(t, handlerContextErr, context.Canceled) }) t.Run("handler_writes", func(t *testing.T) { @@ -2434,7 +2434,7 @@ func TestClientDisconnect(t *testing.T) { gotResponse = make(chan struct{}) ) pingServer := &pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, _ *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { close(gotRequest) var err error for err == nil { diff --git a/error_writer.go b/error_writer.go index 629918ea..466c3b8e 100644 --- a/error_writer.go +++ b/error_writer.go @@ -41,86 +41,54 @@ const ( type ErrorWriter struct { bufferPool *bufferPool protobuf Codec - grpcContentTypes map[string]struct{} - grpcWebContentTypes map[string]struct{} - unaryConnectContentTypes map[string]struct{} - streamingConnectContentTypes map[string]struct{} requireConnectProtocolHeader bool } -// NewErrorWriter constructs an ErrorWriter. To properly recognize supported -// RPC Content-Types in net/http middleware, you must pass the same -// HandlerOptions to NewErrorWriter and any wrapped Connect handlers. +// NewErrorWriter constructs an ErrorWriter. Handler options may be passed to +// configure the error writer behaviour to match the handlers. +// [WithRequiredConnectProtocolHeader] will assert that Connect protocol +// requests include the version header allowing the error writer to correctly +// classify the request. // Options supplied via [WithConditionalHandlerOptions] are ignored. func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { config := newHandlerConfig("", StreamTypeUnary, opts) - writer := &ErrorWriter{ + codecs := newReadOnlyCodecs(config.Codecs) + return &ErrorWriter{ bufferPool: config.BufferPool, - protobuf: newReadOnlyCodecs(config.Codecs).Protobuf(), - grpcContentTypes: make(map[string]struct{}), - grpcWebContentTypes: make(map[string]struct{}), - unaryConnectContentTypes: make(map[string]struct{}), - streamingConnectContentTypes: make(map[string]struct{}), + protobuf: codecs.Protobuf(), requireConnectProtocolHeader: config.RequireConnectProtocolHeader, } - for name := range config.Codecs { - unary := connectContentTypeFromCodecName(StreamTypeUnary, name) - writer.unaryConnectContentTypes[unary] = struct{}{} - streaming := connectContentTypeFromCodecName(StreamTypeBidi, name) - writer.streamingConnectContentTypes[streaming] = struct{}{} - } - if config.HandleGRPC { - writer.grpcContentTypes[grpcContentTypeDefault] = struct{}{} - for name := range config.Codecs { - ct := grpcContentTypeFromCodecName(false /* web */, name) - writer.grpcContentTypes[ct] = struct{}{} - } - } - if config.HandleGRPCWeb { - writer.grpcWebContentTypes[grpcWebContentTypeDefault] = struct{}{} - for name := range config.Codecs { - ct := grpcContentTypeFromCodecName(true /* web */, name) - writer.grpcWebContentTypes[ct] = struct{}{} - } - } - return writer } func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) - if _, ok := w.unaryConnectContentTypes[ctype]; ok { - if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { - return unknownProtocol - } - return connectUnaryProtocol - } - if _, ok := w.streamingConnectContentTypes[ctype]; ok { + isPost := request.Method == http.MethodPost + isGet := request.Method == http.MethodGet + switch { + case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): + return grpcProtocol + case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): + return grpcWebProtocol + case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix): // Streaming ignores the requireConnectProtocolHeader option as the // Content-Type is enough to determine the protocol. if err := connectCheckProtocolVersion(request, false /* required */); err != nil { return unknownProtocol } return connectStreamProtocol - } - if _, ok := w.grpcContentTypes[ctype]; ok { - return grpcProtocol - } - if _, ok := w.grpcWebContentTypes[ctype]; ok { - return grpcWebProtocol - } - // Check for Connect-Protocol-Version header or connect protocol query - // parameter to support connect GET requests. - if request.Method == http.MethodGet { - connectVersion := getHeaderCanonical(request.Header, connectProtocolVersion) - if connectVersion == connectProtocolVersion { - return connectUnaryProtocol + case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix): + if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { + return unknownProtocol } - connectVersion = request.URL.Query().Get(connectUnaryConnectQueryParameter) - if connectVersion == connectUnaryConnectQueryValue { - return connectUnaryProtocol + return connectUnaryProtocol + case isGet: + if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { + return unknownProtocol } + return connectUnaryProtocol + default: + return unknownProtocol } - return unknownProtocol } // IsSupported checks whether a request is using one of the ErrorWriter's diff --git a/error_writer_test.go b/error_writer_test.go index 0b3be022..913b5669 100644 --- a/error_writer_test.go +++ b/error_writer_test.go @@ -24,11 +24,9 @@ import ( func TestErrorWriter(t *testing.T) { t.Parallel() - t.Run("RequireConnectProtocolHeader", func(t *testing.T) { t.Parallel() writer := NewErrorWriter(WithRequireConnectProtocolHeader()) - t.Run("Unary", func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON) @@ -52,4 +50,64 @@ func TestErrorWriter(t *testing.T) { assert.True(t, writer.IsSupported(req)) }) }) + t.Run("Protocols", func(t *testing.T) { + t.Parallel() + writer := NewErrorWriter() // All supported by default + t.Run("ConnectUnary", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("ConnectUnaryGET", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("ConnectStream", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectStreamingContentTypePrefix+codecNameJSON) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPC", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcContentTypeDefault) + assert.True(t, writer.IsSupported(req)) + req.Header.Set("Content-Type", grpcContentTypePrefix+"json") + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPCWeb", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcWebContentTypeDefault) + assert.True(t, writer.IsSupported(req)) + req.Header.Set("Content-Type", grpcWebContentTypePrefix+"json") + assert.True(t, writer.IsSupported(req)) + }) + }) + t.Run("UnknownCodec", func(t *testing.T) { + // An Unknown codec should return supported as the protocol is known and + // the error codec is agnostic to the codec used. The server can respond + // with a protocol error for the unknown codec. + t.Parallel() + writer := NewErrorWriter() + unknownCodec := "invalid" + t.Run("ConnectUnary", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectUnaryContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("ConnectStream", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectStreamingContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPC", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPCWeb", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcWebContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + }) } diff --git a/handler.go b/handler.go index 88e80360..77724bdf 100644 --- a/handler.go +++ b/handler.go @@ -274,8 +274,6 @@ type handlerConfig struct { Procedure string Schema any Initializer maybeInitializer - HandleGRPC bool - HandleGRPCWeb bool RequireConnectProtocolHeader bool IdempotencyLevel IdempotencyLevel BufferPool *bufferPool @@ -290,8 +288,6 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler Procedure: protoPath, CompressionPools: make(map[string]*compressionPool), Codecs: make(map[string]Codec), - HandleGRPC: true, - HandleGRPCWeb: true, BufferPool: newBufferPool(), StreamType: streamType, } @@ -314,12 +310,10 @@ func (c *handlerConfig) newSpec() Spec { } func (c *handlerConfig) newProtocolHandlers() []protocolHandler { - protocols := []protocol{&protocolConnect{}} - if c.HandleGRPC { - protocols = append(protocols, &protocolGRPC{web: false}) - } - if c.HandleGRPCWeb { - protocols = append(protocols, &protocolGRPC{web: true}) + protocols := []protocol{ + &protocolConnect{}, + &protocolGRPC{web: false}, + &protocolGRPC{web: true}, } handlers := make([]protocolHandler, 0, len(protocols)) codecs := newReadOnlyCodecs(c.Codecs)