From 83ee368af9cd65042fecedb0520e6fa63c23a6e6 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Thu, 23 Feb 2023 11:45:13 -0500 Subject: [PATCH] Preserve unknown fields when converting between `FileDescriptorProto` and `ImageFile` (#1855) This way, if a field is ever added to the descriptors and then serialized by a newer version of `buf`, an older version of `buf` won't mangle it and inadvertently drop those newer fields. --- .../bufimagetesting/bufimagetesting_test.go | 71 +++--- private/bufpkg/bufimage/util.go | 54 ++++- private/bufpkg/bufimage/util_test.go | 202 ++++++++++++++++++ .../pkg/protodescriptor/protodescriptor.go | 2 + 4 files changed, 295 insertions(+), 34 deletions(-) create mode 100644 private/bufpkg/bufimage/util_test.go diff --git a/private/bufpkg/bufimage/bufimagetesting/bufimagetesting_test.go b/private/bufpkg/bufimage/bufimagetesting/bufimagetesting_test.go index f9517dc230..af2a449394 100644 --- a/private/bufpkg/bufimage/bufimagetesting/bufimagetesting_test.go +++ b/private/bufpkg/bufimage/bufimagetesting/bufimagetesting_test.go @@ -455,22 +455,19 @@ func TestBasic(t *testing.T) { ) diff := cmp.Diff(protoImage, bufimage.ImageToProtoImage(newImage), protocmp.Transform()) require.Equal(t, "", diff) - // TODO: all of the below proto equality checks should use cmp.Diff - require.Equal( - t, - &descriptorpb.FileDescriptorSet{ - File: []*descriptorpb.FileDescriptorProto{ - testProtoImageFileToFileDescriptorProto(protoImageFileAA), - testProtoImageFileToFileDescriptorProto(protoImageFileImport), - testProtoImageFileToFileDescriptorProto(protoImageFileWellKnownTypeImport), - testProtoImageFileToFileDescriptorProto(protoImageFileAB), - testProtoImageFileToFileDescriptorProto(protoImageFileBA), - testProtoImageFileToFileDescriptorProto(protoImageFileBB), - testProtoImageFileToFileDescriptorProto(protoImageFileOutlandishDirectoryName), - }, + fileDescriptorSet := &descriptorpb.FileDescriptorSet{ + File: []*descriptorpb.FileDescriptorProto{ + testProtoImageFileToFileDescriptorProto(protoImageFileAA), + testProtoImageFileToFileDescriptorProto(protoImageFileImport), + testProtoImageFileToFileDescriptorProto(protoImageFileWellKnownTypeImport), + testProtoImageFileToFileDescriptorProto(protoImageFileAB), + testProtoImageFileToFileDescriptorProto(protoImageFileBA), + testProtoImageFileToFileDescriptorProto(protoImageFileBB), + testProtoImageFileToFileDescriptorProto(protoImageFileOutlandishDirectoryName), }, - bufimage.ImageToFileDescriptorSet(image), - ) + } + diff = cmp.Diff(fileDescriptorSet, bufimage.ImageToFileDescriptorSet(image), protocmp.Transform()) + require.Equal(t, "", diff) codeGeneratorRequest := &pluginpb.CodeGeneratorRequest{ ProtoFile: []*descriptorpb.FileDescriptorProto{ testProtoImageFileToFileDescriptorProto(protoImageFileAA), @@ -490,17 +487,21 @@ func TestBasic(t *testing.T) { "d/d.proto/d.proto", }, } - require.Equal( - t, + diff = cmp.Diff( codeGeneratorRequest, bufimage.ImageToCodeGeneratorRequest(image, "foo", nil, false, false), + protocmp.Transform(), ) + require.Equal(t, "", diff) + // verify that includeWellKnownTypes is a no-op if includeImports is false - require.Equal( - t, + diff = cmp.Diff( codeGeneratorRequest, bufimage.ImageToCodeGeneratorRequest(image, "foo", nil, false, true), + protocmp.Transform(), ) + require.Equal(t, "", diff) + codeGeneratorRequestIncludeImports := &pluginpb.CodeGeneratorRequest{ ProtoFile: []*descriptorpb.FileDescriptorProto{ testProtoImageFileToFileDescriptorProto(protoImageFileAA), @@ -522,11 +523,12 @@ func TestBasic(t *testing.T) { "d/d.proto/d.proto", }, } - require.Equal( - t, + diff = cmp.Diff( codeGeneratorRequestIncludeImports, bufimage.ImageToCodeGeneratorRequest(image, "foo", nil, true, false), + protocmp.Transform(), ) + require.Equal(t, "", diff) newImage, err = bufimage.NewImageForCodeGeneratorRequest(codeGeneratorRequest) require.NoError(t, err) AssertImageFilesEqual( @@ -563,11 +565,12 @@ func TestBasic(t *testing.T) { "d/d.proto/d.proto", }, } - require.Equal( - t, + diff = cmp.Diff( codeGeneratorRequestIncludeImportsAndWellKnownTypes, bufimage.ImageToCodeGeneratorRequest(image, "foo", nil, true, true), + protocmp.Transform(), ) + require.Equal(t, "", diff) // imagesByDir and multiple Image tests imagesByDir, err := bufimage.ImageByDir(image) require.NoError(t, err) @@ -642,11 +645,12 @@ func TestBasic(t *testing.T) { }, }, } - require.Equal( - t, - codeGeneratorRequests, - bufimage.ImagesToCodeGeneratorRequests(imagesByDir, "foo", nil, false, false), - ) + requestsFromImages := bufimage.ImagesToCodeGeneratorRequests(imagesByDir, "foo", nil, false, false) + require.Equal(t, len(codeGeneratorRequests), len(requestsFromImages)) + for i := range codeGeneratorRequests { + diff = cmp.Diff(codeGeneratorRequests[i], requestsFromImages[i], protocmp.Transform()) + require.Equal(t, "", diff) + } codeGeneratorRequestsIncludeImports := []*pluginpb.CodeGeneratorRequest{ { ProtoFile: []*descriptorpb.FileDescriptorProto{ @@ -688,11 +692,12 @@ func TestBasic(t *testing.T) { }, }, } - require.Equal( - t, - codeGeneratorRequestsIncludeImports, - bufimage.ImagesToCodeGeneratorRequests(imagesByDir, "foo", nil, true, false), - ) + requestsFromImages = bufimage.ImagesToCodeGeneratorRequests(imagesByDir, "foo", nil, true, false) + require.Equal(t, len(codeGeneratorRequestsIncludeImports), len(requestsFromImages)) + for i := range codeGeneratorRequestsIncludeImports { + diff = cmp.Diff(codeGeneratorRequestsIncludeImports[i], requestsFromImages[i], protocmp.Transform()) + require.Equal(t, "", diff) + } } func testProtoImageFileToFileDescriptorProto(imageFile *imagev1.ImageFile) *descriptorpb.FileDescriptorProto { diff --git a/private/bufpkg/bufimage/util.go b/private/bufpkg/bufimage/util.go index 0ec8f1f056..d46ca11a4b 100644 --- a/private/bufpkg/bufimage/util.go +++ b/private/bufpkg/bufimage/util.go @@ -24,11 +24,16 @@ import ( "github.com/bufbuild/buf/private/pkg/normalpath" "github.com/bufbuild/buf/private/pkg/protodescriptor" "github.com/bufbuild/buf/private/pkg/stringutil" + "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/pluginpb" ) +// Must match the tag number for ImageFile.buf_extensions defined in proto/buf/alpha/image/v1/image.proto. +const bufExtensionFieldNumber = 8042 + // paths can be either files (ending in .proto) or directories // paths must be normalized and validated, and not duplicated // if a directory, all .proto files underneath will be included @@ -345,7 +350,7 @@ func fileDescriptorProtoToProtoImageFile( if len(unusedDependencyIndexes) == 0 { unusedDependencyIndexes = nil } - return &imagev1.ImageFile{ + resultFile := &imagev1.ImageFile{ Name: fileDescriptorProto.Name, Package: fileDescriptorProto.Package, Syntax: fileDescriptorProto.Syntax, @@ -368,6 +373,53 @@ func fileDescriptorProtoToProtoImageFile( ModuleInfo: protoModuleInfo, }, } + resultFile.ProtoReflect().SetUnknown(stripBufExtensionField(fileDescriptorProto.ProtoReflect().GetUnknown())) + return resultFile +} + +func stripBufExtensionField(unknownFields protoreflect.RawFields) protoreflect.RawFields { + // We accumulate the new bytes in result. However, for efficiency, we don't do any + // allocation/copying until we have to (i.e. until we actually see the field we're + // trying to strip). So result will be left nil and initialized lazily if-and-only-if + // we actually need to strip data from unknownFields. + var result protoreflect.RawFields + bytesRemaining := unknownFields + for len(bytesRemaining) > 0 { + num, wireType, n := protowire.ConsumeTag(bytesRemaining) + if n < 0 { + // shouldn't be possible unless explicitly set to invalid bytes via reflection + return unknownFields + } + var skip bool + if num == bufExtensionFieldNumber { + // We need to strip this field. + skip = true + if result == nil { + // Lazily initialize result to the preface that we've already examined. + result = append( + make(protoreflect.RawFields, 0, len(unknownFields)), + unknownFields[:len(unknownFields)-len(bytesRemaining)]..., + ) + } + } else if result != nil { + // accumulate data in result as we go + result = append(result, bytesRemaining[:n]...) + } + bytesRemaining = bytesRemaining[n:] + n = protowire.ConsumeFieldValue(num, wireType, bytesRemaining) + if n < 0 { + return unknownFields + } + if !skip && result != nil { + result = append(result, bytesRemaining[:n]...) + } + bytesRemaining = bytesRemaining[n:] + } + if result == nil { + // we did not have to remove anything + return unknownFields + } + return result } func imageToCodeGeneratorRequest( diff --git a/private/bufpkg/bufimage/util_test.go b/private/bufpkg/bufimage/util_test.go new file mode 100644 index 0000000000..6aeec05ab8 --- /dev/null +++ b/private/bufpkg/bufimage/util_test.go @@ -0,0 +1,202 @@ +// Copyright 2020-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bufimage + +import ( + "bytes" + "testing" + + "github.com/bufbuild/buf/private/bufpkg/bufmodule/bufmoduleref" + imagev1 "github.com/bufbuild/buf/private/gen/proto/go/buf/alpha/image/v1" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/descriptorpb" +) + +func TestStripBufExtensionField(t *testing.T) { + t.Parallel() + file := &imagev1.ImageFile{ + BufExtension: &imagev1.ImageFileExtension{ + IsImport: proto.Bool(true), + UnusedDependency: []int32{1, 3, 5}, + ModuleInfo: &imagev1.ModuleInfo{ + Name: &imagev1.ModuleName{ + Remote: proto.String("buf.build"), + Owner: proto.String("foo"), + Repository: proto.String("bar"), + }, + Commit: proto.String("1234981234123412341234"), + }, + }, + } + dataToBeStripped, err := proto.Marshal(file) + require.NoError(t, err) + + otherData := protowire.AppendTag(nil, 122, protowire.BytesType) + otherData = protowire.AppendBytes(otherData, []byte{1, 18, 28, 123, 5, 3, 1}) + otherData = protowire.AppendTag(otherData, 123, protowire.VarintType) + otherData = protowire.AppendVarint(otherData, 23456) + otherData = protowire.AppendTag(otherData, 124, protowire.Fixed32Type) + otherData = protowire.AppendFixed32(otherData, 23456) + otherData = protowire.AppendTag(otherData, 125, protowire.Fixed64Type) + otherData = protowire.AppendFixed64(otherData, 23456) + otherData = protowire.AppendTag(otherData, 126, protowire.StartGroupType) + { + otherData = protowire.AppendTag(otherData, 1, protowire.VarintType) + otherData = protowire.AppendVarint(otherData, 123) + otherData = protowire.AppendTag(otherData, 2, protowire.BytesType) + otherData = protowire.AppendBytes(otherData, []byte("foo-bar-baz")) + } + otherData = protowire.AppendTag(otherData, 126, protowire.EndGroupType) + + testCases := []struct { + name string + input []byte + expectedOutput []byte + }{ + { + name: "nothing to strip", + input: otherData, + expectedOutput: otherData, + }, + { + name: "nothing left after strip", + input: dataToBeStripped, + expectedOutput: []byte{}, + }, + { + name: "stripped field at start", + input: bytes.Join([][]byte{dataToBeStripped, otherData}, nil), + expectedOutput: otherData, + }, + { + name: "stripped field at end", + input: bytes.Join([][]byte{otherData, dataToBeStripped}, nil), + expectedOutput: otherData, + }, + { + name: "stripped field in the middle", + input: bytes.Join([][]byte{otherData, dataToBeStripped, otherData}, nil), + expectedOutput: bytes.Repeat(otherData, 2), + }, + } + for i := range testCases { + testCase := testCases[i] + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + stripped := stripBufExtensionField(testCase.input) + require.Equal(t, testCase.expectedOutput, []byte(stripped)) + }) + } +} + +func TestImageToProtoPreservesUnrecognizedFields(t *testing.T) { + t.Parallel() + fileDescriptor := &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo/bar/baz.proto"), + Package: proto.String("foo.bar.baz"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Foo"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("id"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT64.Enum(), + JsonName: proto.String("id"), + }, + { + Name: proto.String("name"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("name"), + }, + }, + }, + }, + } + otherData := protowire.AppendTag(nil, 122, protowire.BytesType) + otherData = protowire.AppendBytes(otherData, []byte{1, 18, 28, 123, 5, 3, 1}) + otherData = protowire.AppendTag(otherData, 123, protowire.VarintType) + otherData = protowire.AppendVarint(otherData, 23456) + otherData = protowire.AppendTag(otherData, 124, protowire.Fixed32Type) + otherData = protowire.AppendFixed32(otherData, 23456) + fileDescriptor.ProtoReflect().SetUnknown(otherData) + + module, err := bufmoduleref.ModuleIdentityForString("buf.build/foo/bar") + require.NoError(t, err) + imageFile, err := NewImageFile( + fileDescriptor, + module, + "1234123451235", + "foo/bar/baz.proto", + false, + false, + nil, + ) + require.NoError(t, err) + + protoImageFile := imageFileToProtoImageFile(imageFile) + // make sure unrecognized bytes survived + require.Equal(t, otherData, []byte(protoImageFile.ProtoReflect().GetUnknown())) + + // now round-trip it back through + imageFileBytes, err := proto.Marshal(protoImageFile) + require.NoError(t, err) + + roundTrippedFileDescriptor := &descriptorpb.FileDescriptorProto{} + err = proto.Unmarshal(imageFileBytes, roundTrippedFileDescriptor) + require.NoError(t, err) + // unrecognized now includes image file's buf extension + require.Greater(t, len(roundTrippedFileDescriptor.ProtoReflect().GetUnknown()), len(otherData)) + + // if we go back through an image file, we should strip out the + // buf extension unknown bytes but preserve the rest + module, err = bufmoduleref.ModuleIdentityForString("buf.build/abc/def") + require.NoError(t, err) + // NB: intentionally different metadata + imageFile, err = NewImageFile( + fileDescriptor, + module, + "987654321", + "abc/def/xyz.proto", + false, + true, + []int32{1, 2, 3}, + ) + require.NoError(t, err) + + protoImageFile = imageFileToProtoImageFile(imageFile) + // make sure unrecognized bytes survived and extraneous buf extension is not present + require.Equal(t, otherData, []byte(protoImageFile.ProtoReflect().GetUnknown())) + + // double-check via round-trip, to make sure resulting image file equals the input + // (to verify that the original unknown bytes byf extension didn't interfere) + imageFileBytes, err = proto.Marshal(protoImageFile) + require.NoError(t, err) + + roundTrippedImageFile := &imagev1.ImageFile{} + err = proto.Unmarshal(imageFileBytes, roundTrippedImageFile) + require.NoError(t, err) + + diff := cmp.Diff(protoImageFile, roundTrippedImageFile, protocmp.Transform()) + require.Empty(t, diff) +} diff --git a/private/pkg/protodescriptor/protodescriptor.go b/private/pkg/protodescriptor/protodescriptor.go index 5c64958741..491c5c2677 100644 --- a/private/pkg/protodescriptor/protodescriptor.go +++ b/private/pkg/protodescriptor/protodescriptor.go @@ -29,6 +29,7 @@ import ( // // Note that a FileDescriptor is not necessarily validated, unlike other interfaces in buf. type FileDescriptor interface { + proto.Message GetName() string GetPackage() string GetDependency() []string @@ -95,6 +96,7 @@ func FileDescriptorProtoForFileDescriptor(fileDescriptor FileDescriptor) *descri if edition := fileDescriptor.GetEdition(); edition != "" { fileDescriptorProto.Edition = proto.String(edition) } + fileDescriptorProto.ProtoReflect().SetUnknown(fileDescriptor.ProtoReflect().GetUnknown()) return fileDescriptorProto }