Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(server/v2): use net/http for matching logic in auto-gateway #23390

Merged
merged 12 commits into from
Jan 15, 2025
11 changes: 3 additions & 8 deletions server/v2/api/grpcgateway/doc.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
// Package grpcgateway provides a custom http mux that utilizes the global gogoproto registry to match
// grpc gateway requests to query handlers. POST requests with JSON bodies and GET requests with query params are supported.
// Wildcard endpoints (i.e. foo/bar/{baz}), as well as catch-all endpoints (i.e. foo/bar/{baz=**} are supported. Using
// header `x-cosmos-block-height` allows you to specify a height for the query.
// Package grpcgateway utilizes the global gogoproto registry to create dynamic query handlers on net/http's mux router.
//
// The URL matching logic is achieved by building regular expressions from the gateway HTTP annotations. These regular expressions
// are then used to match against incoming requests to the HTTP server.
// Header `x-cosmos-block-height` allows you to specify a height for the query.
//
// In cases where the custom http mux is unable to handle the query (i.e. no match found), the request will fall back to the
// ServeMux from github.com/grpc-ecosystem/grpc-gateway/runtime.
// Requests that do not have a dynamic handler registered will be routed to the canonical gRPC gateway mux.
package grpcgateway
302 changes: 302 additions & 0 deletions server/v2/api/grpcgateway/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
package grpcgateway

import (
"bytes"
"errors"
"fmt"
"io"
"maps"
"net/http"
"reflect"
"regexp"
"slices"
"strconv"
"strings"

gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"

"cosmossdk.io/core/transaction"
"cosmossdk.io/log"
"cosmossdk.io/server/v2/appmanager"
)

const MaxBodySize = 1 << 20 // 1 MB

var (
_ http.Handler = &protoHandler[transaction.Tx]{}

wildcardRegex = regexp.MustCompile(`\{([^}]*)\}`)
)

// queryMetadata holds information related to handling gateway queries.
type queryMetadata struct {
// queryInputProtoName is the proto name of the query's input type.
msg gogoproto.Message
// wildcardKeyNames are the wildcard key names from the query's HTTP annotation.
// for example /foo/bar/{baz}/{qux} would produce []string{"baz", "qux"}
// this is used for building the query's path parameter map.
wildcardKeyNames []string
}

// registerGatewayToMux registers handlers for grpc gateway annotations to the httpMux.
func registerGatewayToMux[T transaction.Tx](logger log.Logger, httpMux *http.ServeMux, gateway *runtime.ServeMux, am appmanager.AppManager[T]) error {
annotationMapping, err := newHTTPAnnotationMapping()
if err != nil {
return err
}
annotationToMetadata, err := annotationsToQueryMetadata(annotationMapping)
if err != nil {
return err
}
registerMethods[T](logger, httpMux, am, gateway, annotationToMetadata)
return nil
}

// registerMethods registers the endpoints specified in the annotation mapping to the mux.
func registerMethods[T transaction.Tx](logger log.Logger, mux *http.ServeMux, am appmanager.AppManager[T], gateway *runtime.ServeMux, annotationToMetadata map[string]queryMetadata) {
// register the fallback handler. this will run if the mux isn't able to get a match from the registrations below.
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
gateway.ServeHTTP(w, r)
})

// register in deterministic order. we do this because of the problem mentioned below, and different nodes could
// end up with one version of the handler or the other.
uris := slices.Sorted(maps.Keys(annotationToMetadata))

for _, uri := range uris {
queryMD := annotationToMetadata[uri]
// we need to wrap this in a panic handler because cosmos SDK proto stubs contains a duplicate annotation
// that causes the registration to panic.
func(u string, qMD queryMetadata) {
defer func() {
if err := recover(); err != nil {
logger.Warn("duplicate HTTP annotation detected", "error", err)
}
}()
mux.Handle(u, &protoHandler[T]{
msg: qMD.msg,
gateway: gateway,
appManager: am,
wildcardKeyNames: qMD.wildcardKeyNames,
})
}(uri, queryMD)
}
}

// protoHandler handles turning data in http.Request to the gogoproto.Message
type protoHandler[T transaction.Tx] struct {
// msg is the gogoproto message type.
msg gogoproto.Message
// wildcardKeyNames are the wildcard key names, if any, specified in the http annotation. (i.e. /foo/bar/{baz})
wildcardKeyNames []string
// gateway is the canonical gateway ServeMux to use as a fallback if the query does not have a handler in AppManager.
gateway *runtime.ServeMux
// appManager is used to route queries.
appManager appmanager.AppManager[T]
}

func (p *protoHandler[T]) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
in, out := runtime.MarshalerForRequest(p.gateway, request)

// we clone here as handlers are concurrent and using p.msg would trample.
msg := gogoproto.Clone(p.msg)

// extract path parameters.
params := make(map[string]string)
for _, wildcardKeyName := range p.wildcardKeyNames {
params[wildcardKeyName] = request.PathValue(wildcardKeyName)
}

inputMsg, err := p.populateMessage(request, in, msg, params)
if err != nil {
// the errors returned from the message creation return status errors. no need to make one here.
runtime.HTTPError(request.Context(), p.gateway, out, writer, request, err)
return
}

// get the height from the header.
var height uint64
heightStr := request.Header.Get(GRPCBlockHeightHeader)
heightStr = strings.Trim(heightStr, `\"`)
if heightStr != "" && heightStr != "latest" {
height, err = strconv.ParseUint(heightStr, 10, 64)
if err != nil {
runtime.HTTPError(request.Context(), p.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr))
return
}
}

responseMsg, err := p.appManager.Query(request.Context(), height, inputMsg)
if err != nil {
// if we couldn't find a handler for this request, just fall back to the gateway mux.
if strings.Contains(err.Error(), "no handler") {
p.gateway.ServeHTTP(writer, request)
} else {
// for all other errors, we just return the error.
runtime.HTTPError(request.Context(), p.gateway, out, writer, request, err)
}
return
}

runtime.ForwardResponseMessage(request.Context(), p.gateway, out, writer, request, responseMsg)
}

func (p *protoHandler[T]) populateMessage(req *http.Request, marshaler runtime.Marshaler, input gogoproto.Message, pathParams map[string]string) (gogoproto.Message, error) {
// see if we have path params to populate the message with.
if len(pathParams) > 0 {
for pathKey, pathValue := range pathParams {
if err := runtime.PopulateFieldFromPath(input, pathKey, pathValue); err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Errorf("failed to populate field %s with value %s: %w", pathKey, pathValue, err).Error())
}
}
}

// handle query parameters.
if err := req.ParseForm(); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}

filter := filterFromPathParams(pathParams)
err := runtime.PopulateQueryParameters(input, req.Form, filter)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}

// see if we have a body to unmarshal.
if req.ContentLength > 0 {
if req.ContentLength > MaxBodySize {
return nil, status.Errorf(codes.InvalidArgument, "request body too large: %d bytes, max=%d", req.ContentLength, MaxBodySize)
}

// this block of code ensures that the body can be re-read. this is needed as if the query fails in the
// app's query handler, we need to pass the request back to the canonical gateway, which needs to be able to
// read the body again.
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))

if err = marshaler.NewDecoder(bytes.NewReader(bodyBytes)).Decode(input); err != nil && !errors.Is(err, io.EOF) {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
}

return input, nil
}

func filterFromPathParams(pathParams map[string]string) *utilities.DoubleArray {
var prefixPaths [][]string

for k := range pathParams {
prefixPaths = append(prefixPaths, []string{k})
}

return utilities.NewDoubleArray(prefixPaths)
}

// newHTTPAnnotationMapping returns a mapping of RPC Method HTTP GET annotation to the RPC Handler's Request Input type full name.
//
// example: "/cosmos/auth/v1beta1/account_info/{address}":"cosmos.auth.v1beta1.Query.AccountInfo"
func newHTTPAnnotationMapping() (map[string]string, error) {
protoFiles, err := gogoproto.MergedRegistry()
if err != nil {
return nil, err
}

annotationToQueryInputName := make(map[string]string)
protoFiles.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
for i := 0; i < fd.Services().Len(); i++ {
serviceDesc := fd.Services().Get(i)
for j := 0; j < serviceDesc.Methods().Len(); j++ {
methodDesc := serviceDesc.Methods().Get(j)
httpExtension := proto.GetExtension(methodDesc.Options(), annotations.E_Http)
if httpExtension == nil {
continue
}

httpRule, ok := httpExtension.(*annotations.HttpRule)
if !ok || httpRule == nil {
continue
}
queryInputName := string(methodDesc.Input().FullName())
httpRules := append(httpRule.GetAdditionalBindings(), httpRule)
for _, rule := range httpRules {
if httpAnnotation := rule.GetGet(); httpAnnotation != "" {
annotationToQueryInputName[fixCatchAll(httpAnnotation)] = queryInputName
}
if httpAnnotation := rule.GetPost(); httpAnnotation != "" {
annotationToQueryInputName[fixCatchAll(httpAnnotation)] = queryInputName
}
if httpAnnotation := rule.GetPut(); httpAnnotation != "" {
annotationToQueryInputName[fixCatchAll(httpAnnotation)] = queryInputName
}
if httpAnnotation := rule.GetPatch(); httpAnnotation != "" {
annotationToQueryInputName[fixCatchAll(httpAnnotation)] = queryInputName
}
if httpAnnotation := rule.GetDelete(); httpAnnotation != "" {
annotationToQueryInputName[fixCatchAll(httpAnnotation)] = queryInputName
}
}
}
}
return true
})
return annotationToQueryInputName, nil
}

var catchAllRegex = regexp.MustCompile(`\{([^=]+)=\*\*\}`)

// fixCatchAll replaces grpc gateway catch all syntax with net/http syntax.
//
// {foo=**} -> {foo...}
func fixCatchAll(uri string) string {
return catchAllRegex.ReplaceAllString(uri, `{$1...}`)
}

// annotationsToQueryMetadata takes annotations and creates a mapping of URIs to queryMetadata.
func annotationsToQueryMetadata(annotations map[string]string) (map[string]queryMetadata, error) {
annotationToMetadata := make(map[string]queryMetadata)
for uri, queryInputName := range annotations {
// extract the proto message type.
msgType := gogoproto.MessageType(queryInputName)
if msgType == nil {
continue
}
msg, ok := reflect.New(msgType.Elem()).Interface().(gogoproto.Message)
if !ok {
return nil, fmt.Errorf("query input type %q does not implement gogoproto.Message", queryInputName)
}
annotationToMetadata[uri] = queryMetadata{
msg: msg,
wildcardKeyNames: extractWildcardKeyNames(uri),
}
}
return annotationToMetadata, nil
}

// extractWildcardKeyNames extracts the wildcard key names from the uri annotation.
//
// example:
// "/hello/{world}" -> []string{"world"}
// "/hello/{world}/and/{friends} -> []string{"world", "friends"}
// "/hello/world" -> []string{}
func extractWildcardKeyNames(uri string) []string {
matches := wildcardRegex.FindAllStringSubmatch(uri, -1)
var extracted []string
for _, match := range matches {
// match[0] is the full string including braces (i.e. "{bar}")
// match[1] is the captured group (i.e. "bar")
// we also need to handle the catch-all case with URI's like "bar..." and
// transform them to just "bar".
extracted = append(extracted, strings.TrimRight(match[1], "."))
}
return extracted
}
Loading
Loading