Skip to content

Commit

Permalink
Add proto marshaller for proto-over-http (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewDolan authored and achew22 committed Nov 16, 2017
1 parent 93cf3eb commit 0395325
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
_output/
.idea
62 changes: 62 additions & 0 deletions runtime/marshal_proto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package runtime

import (
"io"

"errors"
"github.com/golang/protobuf/proto"
"io/ioutil"
)

// ProtoMarshaller is a Marshaller which marshals/unmarshals into/from serialize proto bytes
type ProtoMarshaller struct{}

// ContentType always returns "application/octet-stream".
func (*ProtoMarshaller) ContentType() string {
return "application/octet-stream"
}

// Marshal marshals "value" into Proto
func (*ProtoMarshaller) Marshal(value interface{}) ([]byte, error) {
message, ok := value.(proto.Message)
if !ok {
return nil, errors.New("unable to marshal non proto field")
}
return proto.Marshal(message)
}

// Unmarshal unmarshals proto "data" into "value"
func (*ProtoMarshaller) Unmarshal(data []byte, value interface{}) error {
message, ok := value.(proto.Message)
if !ok {
return errors.New("unable to unmarshal non proto field")
}
return proto.Unmarshal(data, message)
}

// NewDecoder returns a Decoder which reads proto stream from "reader".
func (marshaller *ProtoMarshaller) NewDecoder(reader io.Reader) Decoder {
return DecoderFunc(func(value interface{}) error {
buffer, err := ioutil.ReadAll(reader)
if err != nil {
return err
}
return marshaller.Unmarshal(buffer, value)
})
}

// NewEncoder returns an Encoder which writes proto stream into "writer".
func (marshaller *ProtoMarshaller) NewEncoder(writer io.Writer) Encoder {
return EncoderFunc(func(value interface{}) error {
buffer, err := marshaller.Marshal(value)
if err != nil {
return err
}
_, err = writer.Write(buffer)
if err != nil {
return err
}

return nil
})
}
91 changes: 91 additions & 0 deletions runtime/marshal_proto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package runtime_test

import (
"reflect"
"testing"

"bytes"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/grpc-gateway/examples/examplepb"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
)

var message = &examplepb.ABitOfEverything{
SingleNested: &examplepb.ABitOfEverything_Nested{},
RepeatedStringValue: nil,
MappedStringValue: nil,
MappedNestedValue: nil,
RepeatedEnumValue: nil,
TimestampValue: &timestamp.Timestamp{},
Uuid: "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
Nested: []*examplepb.ABitOfEverything_Nested{
{
Name: "foo",
Amount: 12345,
},
},
Uint64Value: 0xFFFFFFFFFFFFFFFF,
EnumValue: examplepb.NumericEnum_ONE,
OneofValue: &examplepb.ABitOfEverything_OneofString{
OneofString: "bar",
},
MapValue: map[string]examplepb.NumericEnum{
"a": examplepb.NumericEnum_ONE,
"b": examplepb.NumericEnum_ZERO,
},
}

func TestProtoMarshalUnmarshal(t *testing.T) {
marshaller := runtime.ProtoMarshaller{}

// Marshal
buffer, err := marshaller.Marshal(message)
if err != nil {
t.Fatalf("Marshalling returned error: %s", err.Error())
}

// Unmarshal
unmarshalled := &examplepb.ABitOfEverything{}
err = marshaller.Unmarshal(buffer, unmarshalled)
if err != nil {
t.Fatalf("Unmarshalling returned error: %s", err.Error())
}

if !reflect.DeepEqual(unmarshalled, message) {
t.Errorf(
"Unmarshalled didn't match original message: (original = %v) != (unmarshalled = %v)",
unmarshalled,
message,
)
}
}

func TestProtoEncoderDecodert(t *testing.T) {
marshaller := runtime.ProtoMarshaller{}

var buf bytes.Buffer

encoder := marshaller.NewEncoder(&buf)
decoder := marshaller.NewDecoder(&buf)

// Encode
err := encoder.Encode(message)
if err != nil {
t.Fatalf("Encoding returned error: %s", err.Error())
}

// Decode
unencoded := &examplepb.ABitOfEverything{}
err = decoder.Decode(unencoded)
if err != nil {
t.Fatalf("Unmarshalling returned error: %s", err.Error())
}

if !reflect.DeepEqual(unencoded, message) {
t.Errorf(
"Unencoded didn't match original message: (original = %v) != (unencoded = %v)",
unencoded,
message,
)
}
}

0 comments on commit 0395325

Please sign in to comment.