Skip to content

Commit

Permalink
fix FileDescriptorSet and merging tests
Browse files Browse the repository at this point in the history
  • Loading branch information
awalterschulze committed Jan 20, 2025
1 parent b0feb48 commit 066ae42
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 35 deletions.
2 changes: 0 additions & 2 deletions debug/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,11 @@ var Output = Nodes{
Field(`D`, `3`),
Nested(`E`,
Nested(`0`,
Field(`A`, `0`),
Nested(`B`,
Field(`0`, `b4`),
),
),
Nested(`1`,
Field(`A`, `0`),
Nested(`B`,
Field(`0`, `b5`),
),
Expand Down
54 changes: 35 additions & 19 deletions proto/nomerge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,19 @@ var (
r = rand.New(rand.NewSource(time.Now().UnixNano()))
)

var debugFileDescriptoSet = protoparser.NewFileDescriptorSet(debug.File_debug_proto)

var msgFileDescriptoSet = protoparser.NewFileDescriptorSet(prototests.File_msg_proto)

var extensionsFileDesciptorSet = protoparser.NewFileDescriptorSet(prototests.File_extensions_proto)

func TestNoMergeNoMerge(t *testing.T) {
m := debug.Input
data, err := proto.Marshal(m)
if err != nil {
t.Fatal(err)
}
err = noMerge(data, m.Description(), "debug", "Debug")
err = noMerge(data, debugFileDescriptoSet, "debug", "Debug")
if err != nil {
t.Fatal(err)
}
Expand All @@ -64,7 +70,7 @@ func TestNoMergeMerge(t *testing.T) {
}
key := byte(uint32(7)<<3 | uint32(1))
data = append(data, key, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))
err = noMerge(data, m.Description(), "debug", "Debug")
err = noMerge(data, debugFileDescriptoSet, "debug", "Debug")
if err == nil || !strings.Contains(err.Error(), "G requires merging") {
t.Fatalf("G should require merging")
}
Expand All @@ -80,7 +86,7 @@ func TestNoMergeLatent(t *testing.T) {
}
key := byte(uint32(6)<<3 | uint32(5))
data = append(data, key, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))
err = noMerge(data, m.Description(), "debug", "Debug")
err = noMerge(data, debugFileDescriptoSet, "debug", "Debug")
if err == nil || !strings.Contains(err.Error(), "F") {
t.Fatalf("F should have latent appending")
}
Expand All @@ -92,7 +98,7 @@ func TestNoMergeNestedNoMerge(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = noMerge(data, bigm.Description(), "prototests", "BigMsg")
err = noMerge(data, msgFileDescriptoSet, "prototests", "BigMsg")
if err != nil {
t.Fatal(err)
}
Expand All @@ -105,32 +111,42 @@ func TestNoMergeMessageMerge(t *testing.T) {
if err != nil {
t.Fatal(err)
}
key := byte(uint32(3)<<3 | uint32(2))
fieldkey := byte(uint32(12)<<3 | uint32(5))
data = append(data, key, 5, fieldkey, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))
err = noMerge(data, bigm.Description(), "prototests", "BigMsg")
if err == nil || !strings.Contains(err.Error(), "requires merging") {
t.Fatal(err)
smallMsgfieldKey := byte(uint32(3)<<3 | uint32(2)) // 3 field number, 2 wire type
flightParachuteFieldKey := byte(uint32(12)<<3 | uint32(5)) // 12 field number, 5 wire type
data = append(data, smallMsgfieldKey, 5, flightParachuteFieldKey, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))
err = noMerge(data, msgFileDescriptoSet, "prototests", "BigMsg")
if err == nil || !strings.Contains(err.Error(), "Msg requires merging") {
t.Fatalf("Msg should require merging, but got Error: <%v>", err)
}
}

func TestNoMergeNestedMerge(t *testing.T) {
bigm := prototests.NewPopulatedBigMsg(r)

m := prototests.NewPopulatedSmallMsg(r)
if len(m.FlightParachute) == 0 {
m.FlightParachute = []uint32{1}
}
m.MapShark = proto.String("a")
key := byte(uint32(12)<<3 | uint32(5))
m.XXX_unrecognized = append(m.XXX_unrecognized, key, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))
bigm.Msg = m
data, err := proto.Marshal(bigm)
mdata, err := proto.Marshal(m)
if err != nil {
t.Fatal(err)
}
flightParachuteFieldKey := byte(uint32(12)<<3 | uint32(5)) // 12 field number, 5 wire type
mdata = append(mdata, flightParachuteFieldKey, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))

bigm := &prototests.BigMsg{
Field: proto.Int64(int64(r.Intn(256))),
}
bigdata, err := proto.Marshal(bigm)
if err != nil {
t.Fatal(err)
}
err = noMerge(data, bigm.Description(), "prototests", "BigMsg")
smallMsgfieldKey := byte(uint32(3)<<3 | uint32(2)) // 3 field number, 2 wire type
bigdata = append(bigdata, smallMsgfieldKey, byte(len(mdata)))
bigdata = append(bigdata, mdata...)
err = noMerge(bigdata, msgFileDescriptoSet, "prototests", "BigMsg")
if err == nil || !strings.Contains(err.Error(), "FlightParachute requires merging") {
t.Fatalf("FlightParachute should require merging %#v", bigm)
t.Fatalf("FlightParachute should require merging, but got Error: <%v>", err)
}
}

Expand All @@ -140,7 +156,7 @@ func TestNoMergeExtensionNoMerge(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = noMerge(data, bigm.Description(), "prototests", "Container")
err = noMerge(data, extensionsFileDesciptorSet, "prototests", "Container")
if err != nil {
t.Fatal(err)
}
Expand All @@ -165,7 +181,7 @@ func TestNoMergeExtensionMerge(t *testing.T) {
n = binary.PutUvarint(datalen, uint64(len(mdata)))
datalen = datalen[:n]
data = append(data, append(datakey, append(datalen, mdata...)...)...)
err = noMerge(data, bigm.Description(), "prototests", "Container")
err = noMerge(data, extensionsFileDesciptorSet, "prototests", "Container")
if err == nil || !strings.Contains(err.Error(), "FieldB requires merging") {
t.Fatalf("FieldB should require merging, but error is %v", err)
}
Expand Down
10 changes: 6 additions & 4 deletions proto/packed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ var packedOutput1 = debug.Nodes{
),
}

var msgFileDescriptorSet = NewFileDescriptorSet(prototests.File_msg_proto)

func TestPacked1(t *testing.T) {
p, err := NewProtoParser("prototests", "Packed", packedInput1.Description())
p, err := NewProtoParser("prototests", "Packed", msgFileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand All @@ -56,7 +58,7 @@ func TestPacked1(t *testing.T) {
}

func TestRandomPacked1(t *testing.T) {
p, err := NewProtoParser("prototests", "Packed", packedInput1.Description())
p, err := NewProtoParser("prototests", "Packed", msgFileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -95,7 +97,7 @@ var packedOutput2 = debug.Nodes{
}

func TestPacked2(t *testing.T) {
p, err := NewProtoParser("prototests", "Packed", packedInput2.Description())
p, err := NewProtoParser("prototests", "Packed", msgFileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand All @@ -114,7 +116,7 @@ func TestPacked2(t *testing.T) {
}

func TestRandomPacked2(t *testing.T) {
p, err := NewProtoParser("prototests", "Packed", packedInput2.Description())
p, err := NewProtoParser("prototests", "Packed", msgFileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 4 additions & 2 deletions proto/proto3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ var proto3Output1 = debug.Nodes{
),
}

var proto3FileDescriptorSet = NewFileDescriptorSet(prototests.File_proto3_proto)

func TestProto31(t *testing.T) {
p, err := NewProtoParser("prototests", "Proto3", proto3Input1.Description())
p, err := NewProtoParser("prototests", "Proto3", proto3FileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand All @@ -68,7 +70,7 @@ func TestProto31(t *testing.T) {
}

func TestRandomProto31(t *testing.T) {
p, err := NewProtoParser("prototests", "Proto3", proto3Input1.Description())
p, err := NewProtoParser("prototests", "Proto3", proto3FileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand Down
18 changes: 11 additions & 7 deletions proto/proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ import (
"google.golang.org/protobuf/proto"
)

var debugFileDescriptoSet = NewFileDescriptorSet(debug.File_debug_proto)

var extensionsFileDescriptorSet = NewFileDescriptorSet(prototests.File_extensions_proto)

func TestDebug(t *testing.T) {
p, err := NewProtoParser("debug", "Debug", debug.DebugDescription())
p, err := NewProtoParser("debug", "Debug", debugFileDescriptoSet)
if err != nil {
t.Fatal(err)
}
Expand All @@ -44,7 +48,7 @@ func TestDebug(t *testing.T) {
}

func TestRandomDebug(t *testing.T) {
p, err := NewProtoParser("debug", "Debug", debug.DebugDescription())
p, err := NewProtoParser("debug", "Debug", debugFileDescriptoSet)
if err != nil {
t.Fatal(err)
}
Expand All @@ -71,7 +75,7 @@ func next(t *testing.T, parser parser.Interface) {
}

func TestSkipRepeated1(t *testing.T) {
p, err := NewProtoParser("debug", "Debug", debug.DebugDescription())
p, err := NewProtoParser("debug", "Debug", debugFileDescriptoSet)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -99,7 +103,7 @@ func TestSkipRepeated1(t *testing.T) {
}

func TestSkipRepeated2(t *testing.T) {
p, err := NewProtoParser("debug", "Debug", debug.DebugDescription())
p, err := NewProtoParser("debug", "Debug", debugFileDescriptoSet)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -129,7 +133,7 @@ func TestSkipRepeated2(t *testing.T) {
}

func TestIndexIsNotAString(t *testing.T) {
p, err := NewProtoParser("debug", "Debug", debug.DebugDescription())
p, err := NewProtoParser("debug", "Debug", debugFileDescriptoSet)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -157,7 +161,7 @@ func TestIndexIsNotAString(t *testing.T) {
}

func TestExtensionsSmallContainer(t *testing.T) {
p, err := NewProtoParser("prototests", "Container", prototests.AContainer.Description())
p, err := NewProtoParser("prototests", "Container", extensionsFileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand All @@ -175,7 +179,7 @@ func TestExtensionsSmallContainer(t *testing.T) {
}

func TestExtensionsBigContainer(t *testing.T) {
p, err := NewProtoParser("prototests", "BigContainer", prototests.ABigContainer.Description())
p, err := NewProtoParser("prototests", "BigContainer", extensionsFileDescriptorSet)
if err != nil {
t.Fatal(err)
}
Expand Down
16 changes: 16 additions & 0 deletions proto/protoreflect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package proto

import (
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
descriptor "google.golang.org/protobuf/types/descriptorpb"
)

// NewFileDescriptorSet is a helper function that converts multiple protoreflect.FileDescriptor into a FileDescriptorSet.
func NewFileDescriptorSet(reflectFileDescriptors ...protoreflect.FileDescriptor) *descriptor.FileDescriptorSet {
fileDescriptors := make([]*descriptor.FileDescriptorProto, len(reflectFileDescriptors))
for i, rfd := range reflectFileDescriptors {
fileDescriptors[i] = protodesc.ToFileDescriptorProto(rfd)
}
return &descriptor.FileDescriptorSet{File: fileDescriptors}
}
2 changes: 1 addition & 1 deletion proto/prototests/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var AContainer = &Container{

func init() {
f := float64(0.123)
proto.SetExtension(AContainer, E_FieldA, &f)
proto.SetExtension(AContainer, E_FieldA, f)
proto.SetExtension(AContainer, E_FieldB, &Small{SmallField: proto.Int64(456)})
proto.SetExtension(AContainer, E_FieldC, &Big{BigField: proto.Int64(789)})
}
Expand Down

0 comments on commit 066ae42

Please sign in to comment.