diff --git a/codec/codec.go b/codec/codec.go index 5490999..6967a74 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -48,6 +48,8 @@ func getSize(v any) (int, error) { switch v := v.(type) { case vtprotoMessage: return v.SizeVT(), nil + case gogoProtoMessage: + return v.Size(), nil case gproto.Message: return gproto.Size(v), nil case protoadapt.MessageV1: @@ -61,6 +63,8 @@ func marshal(v any) ([]byte, error) { switch v := v.(type) { case vtprotoMessage: return v.MarshalVT() + case gogoProtoMessage: + return v.Marshal() case gproto.Message: return gproto.Marshal(v) case protoadapt.MessageV1: @@ -76,6 +80,8 @@ func marshalAppend(dst []byte, v any) error { switch v := v.(type) { case vtprotoMessage: return takeErr(v.MarshalToSizedBufferVT(dst)) + case gogoProtoMessage: + return takeErr(v.MarshalToSizedBuffer(dst)) case gproto.Message: return takeErr((gproto.MarshalOptions{}).MarshalAppend(dst[:0], v)) case protoadapt.MessageV1: @@ -92,6 +98,8 @@ func (vtprotoCodec) Unmarshal(data mem.BufferSlice, v any) error { switch v := v.(type) { case vtprotoMessage: return v.UnmarshalVT(buf.ReadOnlyData()) + case gogoProtoMessage: + return v.Unmarshal(buf.ReadOnlyData()) case gproto.Message: return gproto.Unmarshal(buf.ReadOnlyData(), v) case protoadapt.MessageV1: @@ -110,4 +118,12 @@ type vtprotoMessage interface { SizeVT() int } +type gogoProtoMessage interface { + protoadapt.MessageV1 + MarshalToSizedBuffer([]byte) (int, error) + Marshal() ([]byte, error) + Unmarshal([]byte) error + Size() int +} + func init() { encoding.RegisterCodecV2(vtprotoCodec{}) } diff --git a/codec/codec_test.go b/codec/codec_test.go index f5f97e9..378c9d4 100644 --- a/codec/codec_test.go +++ b/codec/codec_test.go @@ -222,3 +222,70 @@ func benchmarkProtobuf(fn func(t testing.TB)) func(b *testing.B) { } } } + +func TestGoGoProtobuf(t *testing.T) { + tests := map[string]testData{ + "short string": { + length: 42, + allocs: 5, + }, + "long string": { + length: 10240, + allocs: allocsCount, + }, + } + + for name, d := range tests { + if !t.Run(name, func(t *testing.T) { + str := generateString(d.length) + value := (*gogoProto)(wrapperspb.String(str)) + c := checkGogo(str) + + testProtobuf(t, value, c) + + res := testing.Benchmark(benchmarkProtobuf(func(t testing.TB) { testProtobuf(t, value, c) })) + + if allocs := res.AllocsPerOp(); d.allocs != allocs { + t.Fatalf("unexpected number of allocations: expected %d != actual %d", d.allocs, allocs) + } + }) { + break + } + } +} + +func BenchmarkGoGoProtobuf(b *testing.B) { + str := generateString(10240) + value := (*gogoProto)(wrapperspb.String(str)) + c := checkGogo(str) + + benchmarkProtobuf(func(t testing.TB) { testProtobuf(t, value, c) })(b) +} + +func checkGogo(expected string) func(t testing.TB, what *gogoProto) { + return func(t testing.TB, what *gogoProto) { + if expected != what.Value { + t.Fatal("strings are not equal", expected, what.Value) + } + } +} + +// Let's pretend our vt proto is actually gogo proto. +type gogoProto vtwrapperspb.StringValue + +func (x *gogoProto) MarshalToSizedBuffer(b []byte) (int, error) { + return (*vtwrapperspb.StringValue)(x).MarshalToSizedBufferVT(b) +} + +func (x *gogoProto) Marshal() ([]byte, error) { + return (*vtwrapperspb.StringValue)(x).MarshalVT() +} + +func (x *gogoProto) Unmarshal(dest []byte) error { + return (*vtwrapperspb.StringValue)(x).UnmarshalVT(dest) +} + +func (x *gogoProto) Size() int { return (*vtwrapperspb.StringValue)(x).SizeVT() } +func (x *gogoProto) Reset() { (*wrapperspb.StringValue)(x).Reset() } +func (x *gogoProto) String() string { return messageString((*wrapperspb.StringValue)(x)) } +func (*gogoProto) ProtoMessage() {}