From 81a58f893e9e0170031e7e7eb059d081b67cf735 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 25 Jun 2024 10:46:51 -0400 Subject: [PATCH 1/7] Add protocol helpers to infer procedure and type Two new methods are added to allow for inferring the procedure and protocol type of a request. These are provided to be used with http middleware to deduce information about the requests. For example, authentication middleware may wish to block on certain protocols or to conditionally allow routes. Signed-off-by: Edward McFarlane --- error_writer.go | 60 ++++++++++++++++++++++++++---------------------- protocol.go | 39 +++++++++++++++++++++++++++++++ protocol_test.go | 33 ++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 28 deletions(-) diff --git a/error_writer.go b/error_writer.go index f05d19ec..fd50168b 100644 --- a/error_writer.go +++ b/error_writer.go @@ -61,34 +61,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { } func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { - ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) - isPost := request.Method == http.MethodPost - isGet := request.Method == http.MethodGet - switch { - case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): - return grpcProtocol - case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): - return grpcWebProtocol - case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix): - // Streaming ignores the requireConnectProtocolHeader option as the - // Content-Type is enough to determine the protocol. - if err := connectCheckProtocolVersion(request, false /* required */); err != nil { - return unknownProtocol - } - return connectStreamProtocol - case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix): - if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { - return unknownProtocol - } - return connectUnaryProtocol - case isGet: - if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { - return unknownProtocol - } - return connectUnaryProtocol - default: - return unknownProtocol - } + return classifyRequest(request, w.requireConnectProtocolHeader) } // IsSupported checks whether a request is using one of the ErrorWriter's @@ -177,3 +150,34 @@ func (w *ErrorWriter) writeGRPCWeb(response http.ResponseWriter, err error) erro response.WriteHeader(http.StatusOK) return nil } + +func classifyRequest(request *http.Request, requireConnectProtocolHeader bool) protocolType { + ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) + isPost := request.Method == http.MethodPost + isGet := request.Method == http.MethodGet + switch { + case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): + return grpcProtocol + case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): + return grpcWebProtocol + case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix): + // Streaming ignores the requireConnectProtocolHeader option as the + // Content-Type is enough to determine the protocol. + if err := connectCheckProtocolVersion(request, false /* required */); err != nil { + return unknownProtocol + } + return connectStreamProtocol + case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix): + if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil { + return unknownProtocol + } + return connectUnaryProtocol + case isGet: + if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil { + return unknownProtocol + } + return connectUnaryProtocol + default: + return unknownProtocol + } +} diff --git a/protocol.go b/protocol.go index 9add614c..5e5d8804 100644 --- a/protocol.go +++ b/protocol.go @@ -48,6 +48,45 @@ const ( var errNoTimeout = errors.New("no timeout") +// InferProtocolFromRequest returns the inferred protocol name for parsing an +// HTTP request. It inspects the request's method and headers to determine the +// protocol. If the request doesn't match any known protocol, an empty string +// is returned. +func InferProtocolFromRequest(request *http.Request) string { + switch classifyRequest(request, false) { + case connectUnaryProtocol, connectStreamProtocol: + return ProtocolConnect + case grpcProtocol: + return ProtocolGRPC + case grpcWebProtocol: + return ProtocolGRPCWeb + case unknownProtocol: + return "" + default: + return "" + } +} + +// InferProcedureFromURL returns the inferred procedure name from a URL. It's +// returned in the form "/service/method" if a valid suffix is found. If the +// path doesn't contain a service and method, the entire path is returned. +func InferProcedureFromURL(url *url.URL) string { + path := strings.TrimSuffix(url.Path, "/") + ultimate := strings.LastIndex(path, "/") + if ultimate < 0 { + return url.Path + } + penultimate := strings.LastIndex(path[:ultimate], "/") + if penultimate < 0 { + return url.Path + } + procedure := path[penultimate:] + if len(procedure) < 4 { // two slashes + service + method + return url.Path + } + return procedure +} + // A Protocol defines the HTTP semantics to use when sending and receiving // messages. It ties together codecs, compressors, and net/http to produce // Senders and Receivers. diff --git a/protocol_test.go b/protocol_test.go index f35fa0ac..22d8b6a7 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -15,6 +15,7 @@ package connect import ( + "net/url" "testing" "connectrpc.com/connect/internal/assert" @@ -63,3 +64,35 @@ func BenchmarkCanonicalizeContentType(b *testing.B) { b.ReportAllocs() }) } + +func TestProcedureFromURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + url string + want string + }{ + {name: "simple", url: "http://localhost:8080/foo", want: "/foo"}, + {name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar"}, + {name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar"}, + {name: "subroute", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"}, + {name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"}, + { + name: "real", + url: "http://localhost:8080/connect.ping.v1.PingService/Ping", + want: "/connect.ping.v1.PingService/Ping", + }, + } + for _, testcase := range tests { + testcase := testcase + t.Run(testcase.name, func(t *testing.T) { + t.Parallel() + url, err := url.Parse(testcase.url) + if !assert.Nil(t, err) { + return + } + t.Log(url.String()) + assert.Equal(t, InferProcedureFromURL(url), testcase.want) + }) + } +} From 2004b2e665e0e2d5e4e34a1977c28ca3b688f4b0 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 2 Jul 2024 14:43:33 -0400 Subject: [PATCH 2/7] Return bool for indicitaing validity Signed-off-by: Edward McFarlane --- protocol.go | 26 +++++++++++++------------- protocol_test.go | 3 ++- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/protocol.go b/protocol.go index 5e5d8804..b4e0cc6f 100644 --- a/protocol.go +++ b/protocol.go @@ -48,43 +48,43 @@ const ( var errNoTimeout = errors.New("no timeout") -// InferProtocolFromRequest returns the inferred protocol name for parsing an +// ProtocolFromRequest returns the inferred protocol name for parsing an // HTTP request. It inspects the request's method and headers to determine the // protocol. If the request doesn't match any known protocol, an empty string // is returned. -func InferProtocolFromRequest(request *http.Request) string { +func ProtocolFromRequest(request *http.Request) (string, bool) { switch classifyRequest(request, false) { case connectUnaryProtocol, connectStreamProtocol: - return ProtocolConnect + return ProtocolConnect, true case grpcProtocol: - return ProtocolGRPC + return ProtocolGRPC, true case grpcWebProtocol: - return ProtocolGRPCWeb + return ProtocolGRPCWeb, true case unknownProtocol: - return "" + return "", false default: - return "" + return "", false } } -// InferProcedureFromURL returns the inferred procedure name from a URL. It's +// ProcedureFromURL returns the inferred procedure name from a URL. It's // returned in the form "/service/method" if a valid suffix is found. If the // path doesn't contain a service and method, the entire path is returned. -func InferProcedureFromURL(url *url.URL) string { +func ProcedureFromURL(url *url.URL) (string, bool) { path := strings.TrimSuffix(url.Path, "/") ultimate := strings.LastIndex(path, "/") if ultimate < 0 { - return url.Path + return url.Path, false } penultimate := strings.LastIndex(path[:ultimate], "/") if penultimate < 0 { - return url.Path + return url.Path, false } procedure := path[penultimate:] if len(procedure) < 4 { // two slashes + service + method - return url.Path + return url.Path, false } - return procedure + return procedure, true } // A Protocol defines the HTTP semantics to use when sending and receiving diff --git a/protocol_test.go b/protocol_test.go index 22d8b6a7..e733b901 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -92,7 +92,8 @@ func TestProcedureFromURL(t *testing.T) { return } t.Log(url.String()) - assert.Equal(t, InferProcedureFromURL(url), testcase.want) + got, _ := ProcedureFromURL(url) + assert.Equal(t, got, testcase.want) }) } } From 74290d841c9121cb023d7ad80c59a030ad75aca4 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 29 Jul 2024 17:49:14 +0200 Subject: [PATCH 3/7] Fix validation of empty method or service name Signed-off-by: Edward McFarlane --- protocol.go | 10 ++++++---- protocol_test.go | 24 ++++++++++++++---------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/protocol.go b/protocol.go index b4e0cc6f..8e20f9a7 100644 --- a/protocol.go +++ b/protocol.go @@ -51,7 +51,7 @@ var errNoTimeout = errors.New("no timeout") // ProtocolFromRequest returns the inferred protocol name for parsing an // HTTP request. It inspects the request's method and headers to determine the // protocol. If the request doesn't match any known protocol, an empty string -// is returned. +// and false is returned. func ProtocolFromRequest(request *http.Request) (string, bool) { switch classifyRequest(request, false) { case connectUnaryProtocol, connectStreamProtocol: @@ -69,7 +69,8 @@ func ProtocolFromRequest(request *http.Request) (string, bool) { // ProcedureFromURL returns the inferred procedure name from a URL. It's // returned in the form "/service/method" if a valid suffix is found. If the -// path doesn't contain a service and method, the entire path is returned. +// path doesn't contain a service and method, the entire path and false is +// returned. func ProcedureFromURL(url *url.URL) (string, bool) { path := strings.TrimSuffix(url.Path, "/") ultimate := strings.LastIndex(path, "/") @@ -81,10 +82,11 @@ func ProcedureFromURL(url *url.URL) (string, bool) { return url.Path, false } procedure := path[penultimate:] - if len(procedure) < 4 { // two slashes + service + method + // Ensure that the service and method are non-empty. + if ultimate == len(path)-1 || penultimate == ultimate-1 { return url.Path, false } - return procedure, true + return procedure, false } // A Protocol defines the HTTP semantics to use when sending and receiving diff --git a/protocol_test.go b/protocol_test.go index e733b901..7bacea43 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -68,19 +68,23 @@ func BenchmarkCanonicalizeContentType(b *testing.B) { func TestProcedureFromURL(t *testing.T) { t.Parallel() tests := []struct { - name string - url string - want string + name string + url string + want string + valid bool }{ {name: "simple", url: "http://localhost:8080/foo", want: "/foo"}, - {name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar"}, - {name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar"}, - {name: "subroute", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"}, - {name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"}, + {name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar", valid: true}, + {name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar", valid: true}, + {name: "subroute", url: "http://localhost:8080/api/service/bar/", want: "/service/bar", valid: true}, + {name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/service/bar", valid: true}, + {name: "missingService", url: "http://localhost:8080//foo", want: "//foo"}, + {name: "missingMethod", url: "http://localhost:8080/foo//", want: "/foo//"}, { - name: "real", - url: "http://localhost:8080/connect.ping.v1.PingService/Ping", - want: "/connect.ping.v1.PingService/Ping", + name: "real", + url: "http://localhost:8080/connect.ping.v1.PingService/Ping", + want: "/connect.ping.v1.PingService/Ping", + valid: true, }, } for _, testcase := range tests { From 8810d3757d4c2e82713113e46808201baeb1c435 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 6 Aug 2024 21:46:50 +0200 Subject: [PATCH 4/7] Fix testcases Signed-off-by: Edward McFarlane --- protocol.go | 2 +- protocol_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/protocol.go b/protocol.go index 8e20f9a7..f2ef8444 100644 --- a/protocol.go +++ b/protocol.go @@ -86,7 +86,7 @@ func ProcedureFromURL(url *url.URL) (string, bool) { if ultimate == len(path)-1 || penultimate == ultimate-1 { return url.Path, false } - return procedure, false + return procedure, true } // A Protocol defines the HTTP semantics to use when sending and receiving diff --git a/protocol_test.go b/protocol_test.go index 7bacea43..fe381dc9 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -15,6 +15,8 @@ package connect import ( + "net/http" + "net/http/httptest" "net/url" "testing" @@ -65,6 +67,78 @@ func BenchmarkCanonicalizeContentType(b *testing.B) { }) } +func TestProtocolFromRequest(t *testing.T) { + t.Parallel() + tests := []struct { + name string + contentType string + method string + want string + valid bool + }{{ + name: "connectUnary", + contentType: "application/json", + method: http.MethodPost, + want: ProtocolConnect, + valid: true, + }, { + name: "connectStreaming", + contentType: "application/connec+json", + method: http.MethodPost, + want: ProtocolConnect, + valid: true, + }, { + name: "grpcWeb", + contentType: "application/grpc-web", + method: http.MethodPost, + want: ProtocolGRPCWeb, + valid: true, + }, { + name: "grpc", + contentType: "application/grpc", + method: http.MethodPost, + want: ProtocolGRPC, + valid: true, + }, { + name: "connectGet", + contentType: "application/connec+json", + method: http.MethodGet, + want: ProtocolConnect, + valid: true, + }, { + name: "grpcWebGet", + contentType: "application/grpc-web", + method: http.MethodGet, + want: ProtocolConnect, + valid: true, + }, { + name: "grpcGet", + contentType: "application/grpc+json", + method: http.MethodGet, + want: ProtocolConnect, + valid: true, + }, { + name: "unknown", + contentType: "text/html", + method: http.MethodPost, + valid: false, + }} + for _, testcase := range tests { + testcase := testcase + t.Run(testcase.name, func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(testcase.method, "http://localhost:8080/service/Method", nil) + if testcase.contentType != "" { + req.Header.Set("Content-Type", testcase.contentType) + } + req.Method = testcase.method + got, valid := ProtocolFromRequest(req) + assert.Equal(t, got, testcase.want, assert.Sprintf("protocol")) + assert.Equal(t, valid, testcase.valid, assert.Sprintf("valid")) + }) + } +} + func TestProcedureFromURL(t *testing.T) { t.Parallel() tests := []struct { @@ -96,8 +170,9 @@ func TestProcedureFromURL(t *testing.T) { return } t.Log(url.String()) - got, _ := ProcedureFromURL(url) + got, valid := ProcedureFromURL(url) assert.Equal(t, got, testcase.want) + assert.Equal(t, valid, testcase.valid) }) } } From 5d6a6b4e2571afec49f94ddbbee1d0b1ce46ec2d Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 3 Sep 2024 11:09:03 -0400 Subject: [PATCH 5/7] Check for connect GET query params --- error_writer.go | 5 +++++ error_writer_test.go | 13 +++++++++--- protocol_test.go | 49 +++++++++++++++++++++++++++++++------------- 3 files changed, 50 insertions(+), 17 deletions(-) diff --git a/error_writer.go b/error_writer.go index fd50168b..246d63ee 100644 --- a/error_writer.go +++ b/error_writer.go @@ -176,6 +176,11 @@ func classifyRequest(request *http.Request, requireConnectProtocolHeader bool) p if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil { return unknownProtocol } + // Check for Connect required parameters. + params := request.URL.Query() + if !params.Has("message") || !params.Has("encoding") { + return unknownProtocol + } return connectUnaryProtocol default: return unknownProtocol diff --git a/error_writer_test.go b/error_writer_test.go index 913b5669..2b2c5f6d 100644 --- a/error_writer_test.go +++ b/error_writer_test.go @@ -17,6 +17,7 @@ package connect import ( "net/http" "net/http/httptest" + "net/url" "testing" "connectrpc.com/connect/internal/assert" @@ -37,9 +38,11 @@ func TestErrorWriter(t *testing.T) { t.Run("UnaryGET", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) assert.False(t, writer.IsSupported(req)) - query := req.URL.Query() - query.Set(connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue) - req.URL.RawQuery = query.Encode() + req.URL.RawQuery = url.Values{ + connectUnaryConnectQueryParameter: []string{connectUnaryConnectQueryValue}, + connectUnaryEncodingQueryParameter: []string{"json"}, + connectUnaryMessageQueryParameter: []string{"{}"}, + }.Encode() assert.True(t, writer.IsSupported(req)) }) t.Run("Stream", func(t *testing.T) { @@ -60,6 +63,10 @@ func TestErrorWriter(t *testing.T) { }) t.Run("ConnectUnaryGET", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.URL.RawQuery = url.Values{ + connectUnaryEncodingQueryParameter: []string{"json"}, + connectUnaryMessageQueryParameter: []string{"{}"}, + }.Encode() assert.True(t, writer.IsSupported(req)) }) t.Run("ConnectStream", func(t *testing.T) { diff --git a/protocol_test.go b/protocol_test.go index fe381dc9..1d74e498 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -73,6 +73,7 @@ func TestProtocolFromRequest(t *testing.T) { name string contentType string method string + params url.Values want string valid bool }{{ @@ -100,23 +101,40 @@ func TestProtocolFromRequest(t *testing.T) { want: ProtocolGRPC, valid: true, }, { - name: "connectGet", - contentType: "application/connec+json", - method: http.MethodGet, - want: ProtocolConnect, - valid: true, + name: "connectGet", + method: http.MethodGet, + params: url.Values{"message": []string{"{}"}, "encoding": []string{"json"}}, + want: ProtocolConnect, + valid: true, }, { - name: "grpcWebGet", - contentType: "application/grpc-web", - method: http.MethodGet, - want: ProtocolConnect, - valid: true, + name: "connectGetProto", + method: http.MethodGet, + params: url.Values{"message": []string{""}, "encoding": []string{"proto"}}, + want: ProtocolConnect, + valid: true, }, { - name: "grpcGet", - contentType: "application/grpc+json", + name: "connectGetMissingParams", + method: http.MethodGet, + valid: false, + }, { + name: "connectGetMissingParam-Message", + method: http.MethodGet, + params: url.Values{"encoding": []string{"json"}}, + valid: false, + }, { + name: "connectGetMissingParam-Encoding", + method: http.MethodGet, + params: url.Values{"message": []string{"{}"}}, + valid: false, + }, { + name: "connectGetContentType", + contentType: "application/connect+json", method: http.MethodGet, - want: ProtocolConnect, - valid: true, + valid: false, + }, { + name: "nakedGet", + method: http.MethodGet, + valid: false, }, { name: "unknown", contentType: "text/html", @@ -131,6 +149,9 @@ func TestProtocolFromRequest(t *testing.T) { if testcase.contentType != "" { req.Header.Set("Content-Type", testcase.contentType) } + if testcase.params != nil { + req.URL.RawQuery = testcase.params.Encode() + } req.Method = testcase.method got, valid := ProtocolFromRequest(req) assert.Equal(t, got, testcase.want, assert.Sprintf("protocol")) From 17f6741fca58a6375aae0d05913012d1ff96385f Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 3 Sep 2024 11:12:19 -0400 Subject: [PATCH 6/7] Use constants --- error_writer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/error_writer.go b/error_writer.go index 246d63ee..1ffe09d6 100644 --- a/error_writer.go +++ b/error_writer.go @@ -178,7 +178,7 @@ func classifyRequest(request *http.Request, requireConnectProtocolHeader bool) p } // Check for Connect required parameters. params := request.URL.Query() - if !params.Has("message") || !params.Has("encoding") { + if !params.Has(connectUnaryMessageQueryParameter) || !params.Has(connectUnaryEncodingQueryParameter) { return unknownProtocol } return connectUnaryProtocol From 38618d7c4f6c44b33848d16610b88a751330857a Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 3 Sep 2024 11:20:54 -0400 Subject: [PATCH 7/7] Fix invalid test --- protocol_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/protocol_test.go b/protocol_test.go index 1d74e498..d4253913 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -127,9 +127,9 @@ func TestProtocolFromRequest(t *testing.T) { params: url.Values{"message": []string{"{}"}}, valid: false, }, { - name: "connectGetContentType", + name: "connectPutContentType", contentType: "application/connect+json", - method: http.MethodGet, + method: http.MethodPut, valid: false, }, { name: "nakedGet",