From c4d14cfc53a4f819877749e5d68f2c06209f9208 Mon Sep 17 00:00:00 2001 From: Malte Isberner Date: Tue, 16 Aug 2022 13:18:18 +0200 Subject: [PATCH] generate presence checks for bytes fields, where necessary --- .../internal/conformance/equalvt_test.go | 91 +++++++++++++++++++ .../test_messages_proto2_vtproto.pb.go | 4 +- features/equal/equal.go | 14 ++- testproto/proto2/scalars_vtproto.pb.go | 4 +- testproto/proto3opt/opt_vtproto.pb.go | 2 +- 5 files changed, 105 insertions(+), 10 deletions(-) diff --git a/conformance/internal/conformance/equalvt_test.go b/conformance/internal/conformance/equalvt_test.go index a6ef140..f5dccf4 100644 --- a/conformance/internal/conformance/equalvt_test.go +++ b/conformance/internal/conformance/equalvt_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/planetscale/vtprotobuf/testproto/proto3opt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/encoding/protojson" @@ -216,3 +217,93 @@ func TestEqualVT_Oneof_AbsenceVsZeroValue(t *testing.T) { require.NoError(t, err) } } + +func TestEqualVT_Proto2_BytesPresence(t *testing.T) { + a := &TestAllTypesProto2{ + OptionalBytes: nil, + } + b := &TestAllTypesProto2{ + OptionalBytes: []byte{}, + } + + require.False(t, proto.Equal(a, b)) + + aJson, err := protojson.Marshal(a) + require.NoError(t, err) + bJson, err := protojson.Marshal(b) + require.NoError(t, err) + + if a.EqualVT(b) { + assert.JSONEq(t, string(aJson), string(bJson)) + err := fmt.Errorf("these %T should not be equal:\nmsg = %+v\noriginal = %+v", a, a, b) + require.NoError(t, err) + } +} + +func TestEqualVT_Proto3_BytesPresence(t *testing.T) { + a := &proto3opt.OptionalFieldInProto3{ + OptionalBytes: nil, + } + b := &proto3opt.OptionalFieldInProto3{ + OptionalBytes: []byte{}, + } + + require.False(t, proto.Equal(a, b)) + + aJson, err := protojson.Marshal(a) + require.NoError(t, err) + bJson, err := protojson.Marshal(b) + require.NoError(t, err) + + if a.EqualVT(b) { + assert.JSONEq(t, string(aJson), string(bJson)) + err := fmt.Errorf("these %T should not be equal:\nmsg = %+v\noriginal = %+v", a, a, b) + require.NoError(t, err) + } +} + +func TestEqualVT_Proto2_BytesNoPresence(t *testing.T) { + a := &TestAllTypesProto2{ + RepeatedBytes: [][]byte{nil}, + OneofField: &TestAllTypesProto2_OneofBytes{ + OneofBytes: nil, + }, + } + b := &TestAllTypesProto2{ + RepeatedBytes: [][]byte{{}}, + OneofField: &TestAllTypesProto2_OneofBytes{ + OneofBytes: []byte{}, + }, + } + + require.True(t, proto.Equal(a, b)) + + if !a.EqualVT(b) { + err := fmt.Errorf("these %T should be equal:\nmsg = %+v\noriginal = %+v", a, a, b) + require.NoError(t, err) + } +} + +func TestEqualVT_Proto3_BytesNoPresence(t *testing.T) { + a := &TestAllTypesProto3{ + RepeatedBytes: [][]byte{nil}, + OneofField: &TestAllTypesProto3_OneofBytes{ + OneofBytes: nil, + }, + OptionalBytes: nil, + } + b := &TestAllTypesProto3{ + RepeatedBytes: [][]byte{{}}, + OneofField: &TestAllTypesProto3_OneofBytes{ + OneofBytes: []byte{}, + }, + OptionalBytes: []byte{}, + } + + require.True(t, proto.Equal(a, b)) + + if !a.EqualVT(b) { + err := fmt.Errorf("these %T should not be equal:\nmsg = %+v\noriginal = %+v", a, a, b) + require.NoError(t, err) + } +} diff --git a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go index 3d44df0..bcd9171 100644 --- a/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go +++ b/conformance/internal/conformance/test_messages_proto2_vtproto.pb.go @@ -1043,7 +1043,7 @@ func (this *TestAllTypesProto2) EqualVT(that *TestAllTypesProto2) bool { if p, q := this.OptionalString, that.OptionalString; (p == nil && q != nil) || (p != nil && (q == nil || *p != *q)) { return false } - if string(this.OptionalBytes) != string(that.OptionalBytes) { + if p, q := this.OptionalBytes, that.OptionalBytes; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) { return false } if !this.OptionalNestedMessage.EqualVT(that.OptionalNestedMessage) { @@ -1781,7 +1781,7 @@ func (this *TestAllTypesProto2) EqualVT(that *TestAllTypesProto2) bool { if p, q := this.DefaultString, that.DefaultString; (p == nil && q != nil) || (p != nil && (q == nil || *p != *q)) { return false } - if string(this.DefaultBytes) != string(that.DefaultBytes) { + if p, q := this.DefaultBytes, that.DefaultBytes; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) { return false } if p, q := this.Fieldname1, that.Fieldname1; (p == nil && q != nil) || (p != nil && (q == nil || *p != *q)) { diff --git a/features/equal/equal.go b/features/equal/equal.go index 31569ba..c10b690 100644 --- a/features/equal/equal.go +++ b/features/equal/equal.go @@ -134,7 +134,7 @@ func (p *equal) oneof(field *protogen.Field) { case isScalar(kind): p.compareScalar(lhs, rhs, false) case kind == protoreflect.BytesKind: - p.compareBytes(lhs, rhs) + p.compareBytes(lhs, rhs, false) case kind == protoreflect.MessageKind || kind == protoreflect.GroupKind: goTyp, _ := p.FieldGoType(field) p.compareCall(lhs, rhs, goTyp, field.Message) @@ -178,7 +178,7 @@ func (p *equal) field(field *protogen.Field, nullable bool) { p.compareScalar(lhs, rhs, nullable) case kind == protoreflect.BytesKind: - p.compareBytes(lhs, rhs) + p.compareBytes(lhs, rhs, nullable) case kind == protoreflect.MessageKind || kind == protoreflect.GroupKind: goTyp := fmt.Sprintf("*%s", p.QualifiedGoIdent(field.Message.GoIdent)) @@ -204,9 +204,13 @@ func (p *equal) compareScalar(lhs, rhs string, nullable bool) { p.P(`}`) } -func (p *equal) compareBytes(lhs, rhs string) { - // Inlined call to bytes.Equal() - p.P(`if string(`, lhs, `) != string(`, rhs, `) {`) +func (p *equal) compareBytes(lhs, rhs string, nullable bool) { + if nullable { + p.P(`if p, q := `, lhs, `, `, rhs, `; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) {`) + } else { + // Inlined call to bytes.Equal() + p.P(`if string(`, lhs, `) != string(`, rhs, `) {`) + } p.P(` return false`) p.P(`}`) } diff --git a/testproto/proto2/scalars_vtproto.pb.go b/testproto/proto2/scalars_vtproto.pb.go index 0cb64ed..5708bfc 100644 --- a/testproto/proto2/scalars_vtproto.pb.go +++ b/testproto/proto2/scalars_vtproto.pb.go @@ -956,10 +956,10 @@ func (this *BytesMessage) EqualVT(that *BytesMessage) bool { } else if that == nil { return this.String() == "" } - if string(this.RequiredField) != string(that.RequiredField) { + if p, q := this.RequiredField, that.RequiredField; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) { return false } - if string(this.OptionalField) != string(that.OptionalField) { + if p, q := this.OptionalField, that.OptionalField; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) { return false } if len(this.RepeatedField) != len(that.RepeatedField) { diff --git a/testproto/proto3opt/opt_vtproto.pb.go b/testproto/proto3opt/opt_vtproto.pb.go index 50cf1d0..b3deb01 100644 --- a/testproto/proto3opt/opt_vtproto.pb.go +++ b/testproto/proto3opt/opt_vtproto.pb.go @@ -146,7 +146,7 @@ func (this *OptionalFieldInProto3) EqualVT(that *OptionalFieldInProto3) bool { if p, q := this.OptionalString, that.OptionalString; (p == nil && q != nil) || (p != nil && (q == nil || *p != *q)) { return false } - if string(this.OptionalBytes) != string(that.OptionalBytes) { + if p, q := this.OptionalBytes, that.OptionalBytes; (p == nil && q != nil) || (p != nil && q == nil) || string(p) != string(q) { return false } if p, q := this.OptionalEnum, that.OptionalEnum; (p == nil && q != nil) || (p != nil && (q == nil || *p != *q)) {