Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add protocol helpers to infer procedure and type #756

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 32 additions & 28 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
}

func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
case isGet:
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
return classifyRequest(request, w.requireConnectProtocolHeader)
}

// IsSupported checks whether a request is using one of the ErrorWriter's
Expand Down Expand Up @@ -177,3 +150,34 @@ func (w *ErrorWriter) writeGRPCWeb(response http.ResponseWriter, err error) erro
response.WriteHeader(http.StatusOK)
return nil
}

func classifyRequest(request *http.Request, requireConnectProtocolHeader bool) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
case isGet:
if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
}
39 changes: 39 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,45 @@ const (

var errNoTimeout = errors.New("no timeout")

// ProtocolFromRequest returns the inferred protocol name for parsing an
// HTTP request. It inspects the request's method and headers to determine the
// protocol. If the request doesn't match any known protocol, an empty string
// is returned.
func ProtocolFromRequest(request *http.Request) (string, bool) {
Copy link
Member

@jhump jhump Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been thinking about if we'd want to ever return something more structured/richer -- like a Protocol type that can inform the caller (for example) if it's connect streaming vs. connect unary or even the version (in the event we ever add a v2 of the protocol), maybe the codec name, etc.

I suppose we can leave this as is for now. In Spec, we'd have to add another accessor -- like maybe ProtocolDetails, and maybe we could then deprecate this and add ProtocolDetailsFromRequest at that time. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stream is needed by the error writer so it would be nice to expose. It doesn't translate nicely to the grpc or grpc web streaming where we need the descriptor to determine the stream type. These methods are useful, but not that useful so it's annoying for them to clutter the API, they should only be used if the Spec isn't available (maybe that needs to be in the description to avoid misuse?).

My preference would be to move this back to authn-go until we comeup with a better implementation. Maybe we could have a http library in the future that the connect package can utilize with more of these helper methods.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My preference would be to move this back to authn-go until we comeup with a better implementation.

It would be nice to have some canonical implementation that we can also use from vanguard. Admittedly, in the past, we (with @akshayjshah) discussed a possible connectrpc/protocol-go repo where we could put lower-level details of the protocol like this, and connectrpc/connect-go (et al) could then import that. I'm a little concerned with putting it in authn-go if we expect to move it out later. But I guess it's okay as long as we have this figured out before we get to a v1.0 of authn-go 🤷.

switch classifyRequest(request, false) {
case connectUnaryProtocol, connectStreamProtocol:
return ProtocolConnect, true
case grpcProtocol:
return ProtocolGRPC, true
case grpcWebProtocol:
return ProtocolGRPCWeb, true
case unknownProtocol:
return "", false
default:
return "", false
}
}

// ProcedureFromURL returns the inferred procedure name from a URL. It's
// returned in the form "/service/method" if a valid suffix is found. If the
// path doesn't contain a service and method, the entire path is returned.
func ProcedureFromURL(url *url.URL) (string, bool) {
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
path := strings.TrimSuffix(url.Path, "/")
ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return url.Path, false
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return url.Path, false
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
return url.Path, false
}
return procedure, true
}

// A Protocol defines the HTTP semantics to use when sending and receiving
// messages. It ties together codecs, compressors, and net/http to produce
// Senders and Receivers.
Expand Down
34 changes: 34 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package connect

import (
"net/url"
"testing"

"connectrpc.com/connect/internal/assert"
Expand Down Expand Up @@ -63,3 +64,36 @@ func BenchmarkCanonicalizeContentType(b *testing.B) {
b.ReportAllocs()
})
}

func TestProcedureFromURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
want string
}{
{name: "simple", url: "http://localhost:8080/foo", want: "/foo"},
{name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar"},
{name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar"},
{name: "subroute", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"},
{name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"},
{
name: "real",
url: "http://localhost:8080/connect.ping.v1.PingService/Ping",
want: "/connect.ping.v1.PingService/Ping",
},
}
for _, testcase := range tests {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
url, err := url.Parse(testcase.url)
if !assert.Nil(t, err) {
return
}
t.Log(url.String())
got, _ := ProcedureFromURL(url)
assert.Equal(t, got, testcase.want)
})
}
}