diff --git a/protoc-gen-grpc-gateway/descriptor/registry.go b/protoc-gen-grpc-gateway/descriptor/registry.go index f293f555824..123280990d3 100644 --- a/protoc-gen-grpc-gateway/descriptor/registry.go +++ b/protoc-gen-grpc-gateway/descriptor/registry.go @@ -30,6 +30,9 @@ type Registry struct { // pkgAliases is a mapping from package aliases to package paths in go which are already taken. pkgAliases map[string]string + + // allowDeleteBody permits http delete methods to have a body + allowDeleteBody bool } // NewRegistry returns a new Registry. @@ -260,6 +263,12 @@ func (r *Registry) GetAllFQENs() []string { return keys } +// SetAllowDeleteBody controls whether http delete methods may have a +// body or fail loading if encountered. +func (r *Registry) SetAllowDeleteBody(allow bool) { + r.allowDeleteBody = allow +} + // defaultGoPackageName returns the default go package name to be used for go files generated from "f". // You might need to use an unique alias for the package when you import it. Use ReserveGoPackageAlias to get a unique alias. func defaultGoPackageName(f *descriptor.FileDescriptorProto) string { diff --git a/protoc-gen-grpc-gateway/descriptor/services.go b/protoc-gen-grpc-gateway/descriptor/services.go index edf73caf77b..7bd928678f0 100644 --- a/protoc-gen-grpc-gateway/descriptor/services.go +++ b/protoc-gen-grpc-gateway/descriptor/services.go @@ -89,7 +89,7 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, case opts.GetDelete() != "": httpMethod = "DELETE" pathTemplate = opts.GetDelete() - if opts.Body != "" { + if opts.Body != "" && !r.allowDeleteBody { return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName()) } diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go index 7b4af4e6898..eda34d4141e 100644 --- a/protoc-gen-grpc-gateway/descriptor/services_test.go +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -1107,3 +1107,104 @@ func TestResolveFieldPath(t *testing.T) { } } } + +func TestExtractServicesWithDeleteBody(t *testing.T) { + for _, spec := range []struct { + allowDeleteBody bool + expectErr bool + target string + srcs []string + }{ + // body for DELETE, but registry configured to allow it + { + allowDeleteBody: true, + expectErr: false, + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "RemoveResource" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + delete: "/v1/example/resource" + body: "string" + > + > + > + > + `, + }, + }, + // body for DELETE, registry configured not to allow it + { + allowDeleteBody: false, + expectErr: true, + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "RemoveResource" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + delete: "/v1/example/resource" + body: "string" + > + > + > + > + `, + }, + }, + } { + reg := NewRegistry() + reg.SetAllowDeleteBody(spec.allowDeleteBody) + + var fds []*descriptor.FileDescriptorProto + for _, src := range spec.srcs { + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + reg.loadFile(&fd) + fds = append(fds, &fd) + } + err := reg.loadServices(reg.files[spec.target]) + if spec.expectErr && err == nil { + t.Errorf("loadServices(%q) succeeded; want an error; allowDeleteBody=%v, files=%v", spec.target, spec.allowDeleteBody, spec.srcs) + } + if !spec.expectErr && err != nil { + t.Errorf("loadServices(%q) failed; do not want an error; allowDeleteBody=%v, files=%v", spec.target, spec.allowDeleteBody, spec.srcs) + } + t.Log(err) + } +} diff --git a/protoc-gen-swagger/main.go b/protoc-gen-swagger/main.go index fc3a0030a11..db747704e23 100644 --- a/protoc-gen-swagger/main.go +++ b/protoc-gen-swagger/main.go @@ -2,6 +2,7 @@ package main import ( "flag" + "fmt" "io" "io/ioutil" "os" @@ -15,8 +16,9 @@ import ( ) var ( - importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") - file = flag.String("file", "stdin", "where to load data from") + importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") + file = flag.String("file", "stdin", "where to load data from") + allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body") ) func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) { @@ -50,29 +52,21 @@ func main() { if err != nil { glog.Fatal(err) } + pkgMap := make(map[string]string) if req.Parameter != nil { - for _, p := range strings.Split(req.GetParameter(), ",") { - spec := strings.SplitN(p, "=", 2) - if len(spec) == 1 { - if err := flag.CommandLine.Set(spec[0], ""); err != nil { - glog.Fatalf("Cannot set flag %s", p) - } - continue - } - name, value := spec[0], spec[1] - if strings.HasPrefix(name, "M") { - reg.AddPkgMap(name[1:], value) - continue - } - if err := flag.CommandLine.Set(name, value); err != nil { - glog.Fatalf("Cannot set flag %s", p) - } + err := parseReqParam(req.GetParameter(), flag.CommandLine, pkgMap) + if err != nil { + glog.Fatalf("Error parsing flags: %v", err) } } + reg.SetPrefix(*importPrefix) + reg.SetAllowDeleteBody(*allowDeleteBody) + for k, v := range pkgMap { + reg.AddPkgMap(k, v) + } g := genswagger.New(reg) - reg.SetPrefix(*importPrefix) if err := reg.Load(req); err != nil { emitError(err) return @@ -113,3 +107,38 @@ func emitResp(resp *plugin.CodeGeneratorResponse) { glog.Fatal(err) } } + +// parseReqParam parses a CodeGeneratorRequest parameter and adds the +// extracted values to the given FlagSet and pkgMap. Returns a non-nil +// error if setting a flag failed. +func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) error { + if param == "" { + return nil + } + for _, p := range strings.Split(param, ",") { + spec := strings.SplitN(p, "=", 2) + if len(spec) == 1 { + if spec[0] == "allow_delete_body" { + err := f.Set(spec[0], "true") + if err != nil { + return fmt.Errorf("Cannot set flag %s: %v", p, err) + } + continue + } + err := f.Set(spec[0], "") + if err != nil { + return fmt.Errorf("Cannot set flag %s: %v", p, err) + } + continue + } + name, value := spec[0], spec[1] + if strings.HasPrefix(name, "M") { + pkgMap[name[1:]] = value + continue + } + if err := f.Set(name, value); err != nil { + return fmt.Errorf("Cannot set flag %s: %v", p, err) + } + } + return nil +} diff --git a/protoc-gen-swagger/main_test.go b/protoc-gen-swagger/main_test.go new file mode 100644 index 00000000000..c4d12dd20bd --- /dev/null +++ b/protoc-gen-swagger/main_test.go @@ -0,0 +1,129 @@ +package main + +import ( + "flag" + "reflect" + "testing" +) + +func TestParseReqParam(t *testing.T) { + + f := flag.CommandLine + + // this one must be first - with no leading clearFlags call it + // verifies our expectation of default values as we reset by + // clearFlags + pkgMap := make(map[string]string) + expected := map[string]string{} + err := parseReqParam("", f, pkgMap) + if err != nil { + t.Errorf("Test 0: unexpected parse error '%v'", err) + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 0: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(false, "stdin", "", t, 0) + + clearFlags() + pkgMap = make(map[string]string) + expected = map[string]string{"google/api/annotations.proto": "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api"} + err = parseReqParam("allow_delete_body,file=./foo.pb,import_prefix=/bar/baz,Mgoogle/api/annotations.proto=github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api", f, pkgMap) + if err != nil { + t.Errorf("Test 1: unexpected parse error '%v'", err) + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 1: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(true, "./foo.pb", "/bar/baz", t, 1) + + clearFlags() + pkgMap = make(map[string]string) + expected = map[string]string{"google/api/annotations.proto": "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api"} + err = parseReqParam("allow_delete_body=true,file=./foo.pb,import_prefix=/bar/baz,Mgoogle/api/annotations.proto=github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api", f, pkgMap) + if err != nil { + t.Errorf("Test 2: unexpected parse error '%v'", err) + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 2: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(true, "./foo.pb", "/bar/baz", t, 2) + + clearFlags() + pkgMap = make(map[string]string) + expected = map[string]string{"a/b/c.proto": "github.com/x/y/z", "f/g/h.proto": "github.com/1/2/3/"} + err = parseReqParam("allow_delete_body=false,Ma/b/c.proto=github.com/x/y/z,Mf/g/h.proto=github.com/1/2/3/", f, pkgMap) + if err != nil { + t.Errorf("Test 3: unexpected parse error '%v'", err) + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 3: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(false, "stdin", "", t, 3) + + clearFlags() + pkgMap = make(map[string]string) + expected = map[string]string{} + err = parseReqParam("", f, pkgMap) + if err != nil { + t.Errorf("Test 4: unexpected parse error '%v'", err) + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 4: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(false, "stdin", "", t, 4) + + clearFlags() + pkgMap = make(map[string]string) + expected = map[string]string{} + err = parseReqParam("unknown_param=17", f, pkgMap) + if err == nil { + t.Error("Test 5: expected parse error not returned") + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 5: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(false, "stdin", "", t, 5) + + clearFlags() + pkgMap = make(map[string]string) + expected = map[string]string{} + err = parseReqParam("Mfoo", f, pkgMap) + if err == nil { + t.Error("Test 6: expected parse error not returned") + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 6: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(false, "stdin", "", t, 6) + + clearFlags() + pkgMap = make(map[string]string) + expected = map[string]string{} + err = parseReqParam("allow_delete_body,file,import_prefix", f, pkgMap) + if err != nil { + t.Errorf("Test 7: unexpected parse error '%v'", err) + } + if !reflect.DeepEqual(pkgMap, expected) { + t.Errorf("Test 7: pkgMap parse error, expected '%v', got '%v'", expected, pkgMap) + } + checkFlags(true, "", "", t, 7) + +} + +func checkFlags(allowDeleteV bool, fileV, importPathV string, t *testing.T, tid int) { + if *importPrefix != importPathV { + t.Errorf("Test %v: import_prefix misparsed, expected '%v', got '%v'", tid, importPathV, *importPrefix) + } + if *file != fileV { + t.Errorf("Test %v: file misparsed, expected '%v', got '%v'", tid, fileV, *file) + } + if *allowDeleteBody != allowDeleteV { + t.Errorf("Test %v: allow_delete_body misparsed, expected '%v', got '%v'", tid, allowDeleteV, *allowDeleteBody) + } +} + +func clearFlags() { + *importPrefix = "" + *file = "stdin" + *allowDeleteBody = false +}