Skip to content

Commit

Permalink
fix(protoc-gen-openapiv2): exclude from query params oneof fields in …
Browse files Browse the repository at this point in the history
…the same group as the one used in the body
  • Loading branch information
ovargas committed Jan 5, 2024
1 parent 9490c9a commit aa1640d
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 19 deletions.
7 changes: 0 additions & 7 deletions examples/internal/clients/echo/api/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -449,13 +449,6 @@ paths:
type: "string"
x-exportParamName: "StatusNote"
x-optionalDataType: "String"
- name: "en"
in: "query"
required: false
type: "string"
format: "int64"
x-exportParamName: "En"
x-optionalDataType: "String"
responses:
200:
description: "A successful response."
Expand Down
5 changes: 0 additions & 5 deletions examples/internal/clients/echo/api_echo_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,6 @@ EchoServiceApiService EchoBody method receives a simple message and returns it.
* @param "Lang" (optional.String) -
* @param "StatusProgress" (optional.String) -
* @param "StatusNote" (optional.String) -
* @param "En" (optional.String) -
@return ExamplepbSimpleMessage
*/
Expand All @@ -848,7 +847,6 @@ type EchoServiceEchoBody2Opts struct {
Lang optional.String
StatusProgress optional.String
StatusNote optional.String
En optional.String
}

func (a *EchoServiceApiService) EchoServiceEchoBody2(ctx context.Context, id string, no ExamplepbEmbedded, localVarOptionals *EchoServiceEchoBody2Opts) (ExamplepbSimpleMessage, *http.Response, error) {
Expand Down Expand Up @@ -883,9 +881,6 @@ func (a *EchoServiceApiService) EchoServiceEchoBody2(ctx context.Context, id str
if localVarOptionals != nil && localVarOptionals.StatusNote.IsSet() {
localVarQueryParams.Add("status.note", parameterToString(localVarOptionals.StatusNote.Value(), ""))
}
if localVarOptionals != nil && localVarOptionals.En.IsSet() {
localVarQueryParams.Add("en", parameterToString(localVarOptionals.En.Value(), ""))
}
// to determine the Content-Type header
localVarHttpContentTypes := []string{"application/json"}

Expand Down
7 changes: 0 additions & 7 deletions examples/internal/proto/examplepb/echo_service.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,6 @@
"in": "query",
"required": false,
"type": "string"
},
{
"name": "en",
"in": "query",
"required": false,
"type": "string",
"format": "int64"
}
],
"tags": [
Expand Down
108 changes: 108 additions & 0 deletions protoc-gen-openapiv2/internal/genopenapi/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,108 @@ func TestGenerateRPCOrderPreservedAdditionalBindings(t *testing.T) {
}
}

func TestGenerateRPCOneOfFieldBodyAdditionalBindings(t *testing.T) {
t.Parallel()

const in = `
file_to_generate: "exampleproto/v1/example.proto"
parameter: "output_format=yaml,allow_delete_body=true"
proto_file: {
name: "exampleproto/v1/example.proto"
package: "example.v1"
message_type: {
name: "Foo"
oneof_decl: {
name: "foo"
}
field: {
name: "bar"
number: 1
label: LABEL_OPTIONAL
type: TYPE_STRING
json_name: "bar"
oneof_index: 0
}
field: {
name: "baz"
number: 2
label: LABEL_OPTIONAL
type: TYPE_STRING
json_name: "bar"
oneof_index: 0
}
}
service: {
name: "TestService"
method: {
name: "Test1"
input_type: ".example.v1.Foo"
output_type: ".example.v1.Foo"
options: {
[google.api.http]: {
post: "/b/foo"
body: "*"
additional_bindings {
post: "/b/foo/bar"
body: "bar"
}
additional_bindings {
post: "/b/foo/baz"
body: "baz"
}
}
}
}
}
options: {
go_package: "exampleproto/v1;exampleproto"
}
}`

var req pluginpb.CodeGeneratorRequest
if err := prototext.Unmarshal([]byte(in), &req); err != nil {
t.Fatalf("failed to marshall yaml: %s", err)
}

formats := [...]genopenapi.Format{
genopenapi.FormatJSON,
genopenapi.FormatYAML,
}

for _, format := range formats {
format := format
t.Run(string(format), func(t *testing.T) {
t.Parallel()

resp := requireGenerate(t, &req, format, true, false)
if len(resp) != 1 {
t.Fatalf("invalid count, expected: 1, actual: %d", len(resp))
}

content := resp[0].GetContent()

t.Log(content)

contentsSlice := strings.Fields(content)
expectedPaths := []string{"/b/foo", "/b/foo/bar", "/b/foo/baz"}

foundPaths := []string{}
for _, contentValue := range contentsSlice {
findExpectedPaths(&foundPaths, expectedPaths, contentValue)
}

if allPresent := reflect.DeepEqual(foundPaths, expectedPaths); !allPresent {
t.Fatalf("Found paths differed from expected paths. Got: %#v, want %#v", foundPaths, expectedPaths)
}

// The input message only contains oneof fields, so no other fields should be mapped to the query.
if strings.Contains(content, "query") {
t.Fatalf("Found query in content, expected not to find any")
}
})
}
}

func TestGenerateRPCOrderNotPreservedAdditionalBindings(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1694,6 +1796,12 @@ func TestGenerateRPCOrderNotPreservedMergeFilesAdditionalBindingsMultipleService
func findExpectedPaths(foundPaths *[]string, expectedPaths []string, potentialPath string) {
seenPaths := map[string]struct{}{}

// foundPaths may bot be empty when this function is called multiple times.
// so we add them to seenPaths map to avoid duplicates.
for _, path := range *foundPaths {
seenPaths[path] = struct{}{}
}

for _, path := range expectedPaths {
_, pathAlreadySeen := seenPaths[path]
if strings.Contains(potentialPath, path) && !pathAlreadySeen {
Expand Down
21 changes: 21 additions & 0 deletions protoc-gen-openapiv2/internal/genopenapi/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ func getEnumDefaultNumber(reg *descriptor.Registry, enum *descriptor.Enum) inter
// messageToQueryParameters converts a message to a list of OpenAPI query parameters.
func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body, httpMethod string) (params []openapiParameterObject, err error) {
for _, field := range message.Fields {
// When body is set to oneof field, we want to skip other fields in the oneof group.
if isBodySameOneOf(body, field) {
continue
}

if !isVisible(getFieldVisibilityOption(field), reg) {
continue
}
Expand All @@ -183,6 +188,22 @@ func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Regis
return params, nil
}

func isBodySameOneOf(body *descriptor.Body, field *descriptor.Field) bool {
if field.OneofIndex == nil {
return false
}

if body == nil || len(body.FieldPath) == 0 {
return false
}

if body.FieldPath[0].Target.OneofIndex == nil {
return false
}

return *body.FieldPath[0].Target.OneofIndex == *field.OneofIndex
}

// queryParams converts a field to a list of OpenAPI query parameters recursively through the use of nestedQueryParams.
func queryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body, recursiveCount int) (params []openapiParameterObject, err error) {
return nestedQueryParams(message, field, prefix, reg, pathParams, body, newCycleChecker(recursiveCount))
Expand Down

0 comments on commit aa1640d

Please sign in to comment.