diff --git a/gengokit/httptransport/httptransport.go b/gengokit/httptransport/httptransport.go index b0d920f4..bee6d58e 100644 --- a/gengokit/httptransport/httptransport.go +++ b/gengokit/httptransport/httptransport.go @@ -6,6 +6,7 @@ import ( "bytes" "fmt" "go/format" + "regexp" "strconv" "strings" "text/template" @@ -71,7 +72,7 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding { binding := meth.Bindings[i] nBinding := Binding{ Label: meth.Name + EnglishNumber(i), - PathTemplate: binding.Path, + PathTemplate: getMuxPathTemplate(binding.Path), BasePath: basePath(binding.Path), Verb: binding.Verb, } @@ -240,6 +241,12 @@ func (b *Binding) GenClientEncode() (string, error) { // "fmt.Sprint(req.A)", // } func (b *Binding) PathSections() []string { + path := b.PathTemplate + re := regexp.MustCompile(`{.+:.+}`) + path = re.ReplaceAllStringFunc(path, func(v string) string { + return strings.Split(v, ":")[0] + "}" + }) + isEnum := make(map[string]struct{}) for _, v := range b.Fields { if v.IsEnum { @@ -248,7 +255,7 @@ func (b *Binding) PathSections() []string { } rv := []string{} - parts := strings.Split(b.PathTemplate, "/") + parts := strings.Split(path, "/") for _, part := range parts { if len(part) > 2 && part[0] == '{' && part[len(part)-1] == '}' { name := RemoveBraces(part) @@ -490,6 +497,19 @@ func getZeroValue(f Field) string { } } +// getMuxPathTemplate translates gRPC Transcoding path into gorilla/mux +// compatible path template. +func getMuxPathTemplate(path string) string { + re := regexp.MustCompile(`{.+=.+}`) + stars := regexp.MustCompile(`\*{2,}`) + return re.ReplaceAllStringFunc(path, func(v string) string { + v = strings.Replace(v, "=", ":", 1) + v = stars.ReplaceAllLiteralString(v, `.+`) + v = strings.ReplaceAll(v, "*", `[^/]+`) + return v + }) +} + // The 'basePath' of a path is the section from the start of the string till // the first '{' character. func basePath(path string) string { diff --git a/gengokit/httptransport/httptransport_test.go b/gengokit/httptransport/httptransport_test.go index ff9b9246..322926b7 100644 --- a/gengokit/httptransport/httptransport_test.go +++ b/gengokit/httptransport/httptransport_test.go @@ -179,3 +179,87 @@ func TestLowCamelName(t *testing.T) { } } } + +func Test_getMuxPathTemplate(t *testing.T) { + tests := []struct { + name string + path string + want string + }{ + { + name: "no pattern", + path: "/v1/{parent}/books", + want: "/v1/{parent}/books", + }, + { + name: "no *", + path: "/v1/{parent=shelves}/books", + want: "/v1/{parent:shelves}/books", + }, + { + name: "single *", + path: "/v1/{parent=shelves/*}/books", + want: `/v1/{parent:shelves/[^/]+}/books`, + }, + { + name: "multiple *", + path: "/v1/{name=shelves/*/books/*}", + want: `/v1/{name:shelves/[^/]+/books/[^/]+}`, + }, + { + name: "**", + path: "/v1/shelves/{name=books/**}", + want: `/v1/shelves/{name:books/.+}`, + }, + { + name: "mixed * and **", + path: "/v1/{name=shelves/*/books/**}", + want: `/v1/{name:shelves/[^/]+/books/.+}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getMuxPathTemplate(tt.path); got != tt.want { + t.Errorf("getMuxPathTemplate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBinding_PathSections(t *testing.T) { + tests := []struct { + name string + pathTemplate string + want []string + }{ + { + name: "simple", + pathTemplate: "/sum/{a}", + want: []string{ + `""`, + `"sum"`, + "fmt.Sprint(req.A)", + }, + }, + { + name: "pattern", + pathTemplate: `/v1/{parent:shelves/[^/]+}/books`, + want: []string{ + `""`, + `"v1"`, + "fmt.Sprint(req.Parent)", + `"books"`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &Binding{ + PathTemplate: tt.pathTemplate, + } + if got := b.PathSections(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Binding.PathSections() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/svcdef/consolidate_http.go b/svcdef/consolidate_http.go index 84efe384..44a36ee4 100644 --- a/svcdef/consolidate_http.go +++ b/svcdef/consolidate_http.go @@ -6,8 +6,8 @@ import ( "regexp" "strings" - log "github.com/sirupsen/logrus" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" gogen "github.com/gogo/protobuf/protoc-gen-gogo/generator" @@ -164,11 +164,10 @@ func paramLocation(field *Field, binding *svcparse.HTTPBinding) string { func getPathParams(binding *svcparse.HTTPBinding) []string { _, path := getVerb(binding) findParams := regexp.MustCompile("{(.*?)}") - removeBraces := regexp.MustCompile("{|}") params := findParams.FindAllString(path, -1) rv := []string{} for _, p := range params { - rv = append(rv, removeBraces.ReplaceAllString(p, "")) + rv = append(rv, strings.Split(p[1:len(p)-1], "=")[0]) } return rv } diff --git a/svcdef/consolidate_http_test.go b/svcdef/consolidate_http_test.go index 484cec91..b24ba4cf 100644 --- a/svcdef/consolidate_http_test.go +++ b/svcdef/consolidate_http_test.go @@ -10,21 +10,36 @@ import ( ) func TestGetPathParams(t *testing.T) { - binding := &svcparse.HTTPBinding{ - Fields: []*svcparse.Field{ - &svcparse.Field{ - Kind: "get", - Value: `"/{a}/{b}"`, - }, + tests := []struct { + name string + value string + want []string + }{ + { + name: "basic", + value: `"/{a}/{b}"`, + want: []string{"a", "b"}, + }, + { + name: "variable with path segments", + value: `"/v1/{parent=shelves/*}/books"`, + want: []string{"parent"}, }, } - params := getPathParams(binding) - if len(params) != 2 { - t.Fatalf("Params (%v) is length '%v', expected length 2", params, len(params)) - } - expected := []string{"a", "b"} - if !reflect.DeepEqual(params, expected) { - t.Fatalf("Params is %v, expected %v", params, expected) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + binding := &svcparse.HTTPBinding{ + Fields: []*svcparse.Field{ + { + Kind: "get", + Value: tt.value, + }, + }, + } + if got := getPathParams(binding); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getPathParams() = %v, want %v", got, tt.want) + } + }) } }