diff --git a/codegenerator/doc.go b/codegenerator/doc.go new file mode 100644 index 00000000000..3645317175b --- /dev/null +++ b/codegenerator/doc.go @@ -0,0 +1,4 @@ +/* +Package codegenerator contains reusable functions used by the code generators. +*/ +package codegenerator diff --git a/codegenerator/parse_req.go b/codegenerator/parse_req.go new file mode 100644 index 00000000000..e74575bddcd --- /dev/null +++ b/codegenerator/parse_req.go @@ -0,0 +1,23 @@ +package codegenerator + +import ( + "fmt" + "io" + "io/ioutil" + + "github.com/golang/protobuf/proto" + plugin "github.com/golang/protobuf/protoc-gen-go/plugin" +) + +// ParseRequest parses a code generator request from a proto Message. +func ParseRequest(r io.Reader) (*plugin.CodeGeneratorRequest, error) { + input, err := ioutil.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("failed to read code generator request: %v", err) + } + req := new(plugin.CodeGeneratorRequest) + if err = proto.Unmarshal(input, req); err != nil { + return nil, fmt.Errorf("failed to unmarshal code generator request: %v", err) + } + return req, nil +} diff --git a/codegenerator/parse_req_test.go b/codegenerator/parse_req_test.go new file mode 100644 index 00000000000..5f37aad9589 --- /dev/null +++ b/codegenerator/parse_req_test.go @@ -0,0 +1,69 @@ +package codegenerator_test + +import ( + "bytes" + "fmt" + "io" + "reflect" + "strings" + "testing" + + "github.com/golang/protobuf/proto" + plugin "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/grpc-ecosystem/grpc-gateway/codegenerator" +) + +var parseReqTests = []struct { + name string + in io.Reader + out *plugin.CodeGeneratorRequest + err error +}{ + { + "Empty input should produce empty output", + mustGetReader(&plugin.CodeGeneratorRequest{}), + &plugin.CodeGeneratorRequest{}, + nil, + }, + { + "Invalid reader should produce error", + &invalidReader{}, + nil, + fmt.Errorf("failed to read code generator request: invalid reader"), + }, + { + "Invalid proto message should produce error", + strings.NewReader("{}"), + nil, + fmt.Errorf("failed to unmarshal code generator request: unexpected EOF"), + }, +} + +func TestParseRequest(t *testing.T) { + for _, tt := range parseReqTests { + t.Run(tt.name, func(t *testing.T) { + out, err := codegenerator.ParseRequest(tt.in) + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("got %v, want %v", err, tt.err) + } + if err == nil && !reflect.DeepEqual(*out, *tt.out) { + t.Errorf("got %v, want %v", *out, *tt.out) + } + }) + } +} + +func mustGetReader(pb proto.Message) io.Reader { + b, err := proto.Marshal(pb) + if err != nil { + panic(err) + } + return bytes.NewBuffer(b) +} + +type invalidReader struct { +} + +func (*invalidReader) Read(p []byte) (int, error) { + return 0, fmt.Errorf("invalid reader") +} diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 4b875d51b36..d4569d2079b 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -10,14 +10,13 @@ package main import ( "flag" - "io" - "io/ioutil" "os" "strings" "github.com/golang/glog" "github.com/golang/protobuf/proto" plugin "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/grpc-ecosystem/grpc-gateway/codegenerator" "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/gengateway" ) @@ -29,33 +28,18 @@ var ( allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body") ) -func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) { - glog.V(1).Info("Parsing code generator request") - input, err := ioutil.ReadAll(r) - if err != nil { - glog.Errorf("Failed to read code generator request: %v", err) - return nil, err - } - req := new(plugin.CodeGeneratorRequest) - if err = proto.Unmarshal(input, req); err != nil { - glog.Errorf("Failed to unmarshal code generator request: %v", err) - return nil, err - } - glog.V(1).Info("Parsed code generator request") - return req, nil -} - func main() { flag.Parse() defer glog.Flush() reg := descriptor.NewRegistry() - glog.V(1).Info("Processing code generator request") - req, err := parseReq(os.Stdin) + glog.V(1).Info("Parsing code generator request") + req, err := codegenerator.ParseRequest(os.Stdin) if err != nil { glog.Fatal(err) } + glog.V(1).Info("Parsed code generator request") if req.Parameter != nil { for _, p := range strings.Split(req.GetParameter(), ",") { spec := strings.SplitN(p, "=", 2) diff --git a/protoc-gen-swagger/main.go b/protoc-gen-swagger/main.go index ebbaecbd36f..3d7f1ab7580 100644 --- a/protoc-gen-swagger/main.go +++ b/protoc-gen-swagger/main.go @@ -3,14 +3,13 @@ package main import ( "flag" "fmt" - "io" - "io/ioutil" "os" "strings" "github.com/golang/glog" "github.com/golang/protobuf/proto" plugin "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/grpc-ecosystem/grpc-gateway/codegenerator" "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger/genswagger" ) @@ -21,22 +20,6 @@ var ( allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body") ) -func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) { - glog.V(1).Info("Parsing code generator request") - input, err := ioutil.ReadAll(r) - if err != nil { - glog.Errorf("Failed to read code generator request: %v", err) - return nil, err - } - req := new(plugin.CodeGeneratorRequest) - if err = proto.Unmarshal(input, req); err != nil { - glog.Errorf("Failed to unmarshal code generator request: %v", err) - return nil, err - } - glog.V(1).Info("Parsed code generator request") - return req, nil -} - func main() { flag.Parse() defer glog.Flush() @@ -52,10 +35,12 @@ func main() { glog.Fatal(err) } } - req, err := parseReq(f) + glog.V(1).Info("Parsing code generator request") + req, err := codegenerator.ParseRequest(f) if err != nil { glog.Fatal(err) } + glog.V(1).Info("Parsed code generator request") pkgMap := make(map[string]string) if req.Parameter != nil { err := parseReqParam(req.GetParameter(), flag.CommandLine, pkgMap)