Skip to content

Commit

Permalink
make unknown proto not use gogo code gen types, it now uses descripto…
Browse files Browse the repository at this point in the history
…rpb types
  • Loading branch information
unknown unknown committed Nov 22, 2023
1 parent 8644e6b commit 8ada39e
Showing 1 changed file with 115 additions and 49 deletions.
164 changes: 115 additions & 49 deletions codec/unknownproto/unknown_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"strings"
"sync"

"github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/gogoproto/jsonpb"
"github.com/cosmos/gogoproto/proto"
"github.com/cosmos/gogoproto/protoc-gen-gogo/descriptor"
"google.golang.org/protobuf/encoding/protowire"

"github.com/cosmos/cosmos-sdk/codec/types"
protov2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
)

const bit11NonCritical = 1 << 10
Expand Down Expand Up @@ -68,10 +68,9 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
GotWireType: wireType,
WantWireType: protowire.Type(fieldDescProto.WireType()),
WantWireType: findWireTypeFromFieldDescriptorProtoType(fieldDescProto.GetType()),
}
}

default:
isCriticalField := tagNum&bit11NonCritical == 0

Expand Down Expand Up @@ -101,14 +100,14 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
bz = bz[n:]

// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
if fieldDescProto == nil || fieldDescProto.IsScalar() {
if fieldDescProto == nil || isScalar(fieldDescProto) {
continue
}

protoMessageName := fieldDescProto.GetTypeName()
if protoMessageName == "" {
switch typ := fieldDescProto.GetType(); typ {
case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES:
case descriptorpb.FieldDescriptorProto_TYPE_STRING, descriptorpb.FieldDescriptorProto_TYPE_BYTES:
// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
// TYPE_BYTES and TYPE_STRING as per
// https://github.com/cosmos/gogoproto/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
Expand Down Expand Up @@ -199,68 +198,68 @@ func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
// checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type.
// it is implemented this way so as to have constant time lookups and avoid the overhead
// from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{
var checks = [...]map[descriptorpb.FieldDescriptorProto_Type]bool{
// "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
0: {
descriptor.FieldDescriptorProto_TYPE_INT32: true,
descriptor.FieldDescriptorProto_TYPE_INT64: true,
descriptor.FieldDescriptorProto_TYPE_UINT32: true,
descriptor.FieldDescriptorProto_TYPE_UINT64: true,
descriptor.FieldDescriptorProto_TYPE_SINT32: true,
descriptor.FieldDescriptorProto_TYPE_SINT64: true,
descriptor.FieldDescriptorProto_TYPE_BOOL: true,
descriptor.FieldDescriptorProto_TYPE_ENUM: true,
descriptorpb.FieldDescriptorProto_TYPE_INT32: true,
descriptorpb.FieldDescriptorProto_TYPE_INT64: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_BOOL: true,
descriptorpb.FieldDescriptorProto_TYPE_ENUM: true,
},

// "1 64-bit: fixed64, sfixed64, double"
1: {
descriptor.FieldDescriptorProto_TYPE_FIXED64: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptor.FieldDescriptorProto_TYPE_DOUBLE: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true,
},

// "2 Length-delimited: string, bytes, embedded messages, packed repeated fields"
2: {
descriptor.FieldDescriptorProto_TYPE_STRING: true,
descriptor.FieldDescriptorProto_TYPE_BYTES: true,
descriptor.FieldDescriptorProto_TYPE_MESSAGE: true,
descriptorpb.FieldDescriptorProto_TYPE_STRING: true,
descriptorpb.FieldDescriptorProto_TYPE_BYTES: true,
descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: true,
// The following types can be packed repeated.
// ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"."
// ref: https://developers.google.com/protocol-buffers/docs/encoding#packed
descriptor.FieldDescriptorProto_TYPE_INT32: true,
descriptor.FieldDescriptorProto_TYPE_INT64: true,
descriptor.FieldDescriptorProto_TYPE_UINT32: true,
descriptor.FieldDescriptorProto_TYPE_UINT64: true,
descriptor.FieldDescriptorProto_TYPE_SINT32: true,
descriptor.FieldDescriptorProto_TYPE_SINT64: true,
descriptor.FieldDescriptorProto_TYPE_BOOL: true,
descriptor.FieldDescriptorProto_TYPE_ENUM: true,
descriptor.FieldDescriptorProto_TYPE_FIXED64: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptor.FieldDescriptorProto_TYPE_DOUBLE: true,
descriptorpb.FieldDescriptorProto_TYPE_INT32: true,
descriptorpb.FieldDescriptorProto_TYPE_INT64: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_BOOL: true,
descriptorpb.FieldDescriptorProto_TYPE_ENUM: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true,
},

// "3 Start group: groups (deprecated)"
3: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
},

// "4 End group: groups (deprecated)"
4: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
},

// "5 32-bit: fixed32, sfixed32, float"
5: {
descriptor.FieldDescriptorProto_TYPE_FIXED32: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
descriptor.FieldDescriptorProto_TYPE_FLOAT: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED32: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: true,
descriptorpb.FieldDescriptorProto_TYPE_FLOAT: true,
},
}

// canEncodeType returns true if the wireType is suitable for encoding the descriptor type.
// See https://developers.google.com/protocol-buffers/docs/encoding#structure.
func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool {
func canEncodeType(wireType protowire.Type, descType descriptorpb.FieldDescriptorProto_Type) bool {
if iwt := int(wireType); iwt < 0 || iwt >= len(checks) {
return false
}
Expand Down Expand Up @@ -330,21 +329,21 @@ func (twt *errUnknownField) Error() string {
var _ error = (*errUnknownField)(nil)

var (
protoFileToDesc = make(map[string]*descriptor.FileDescriptorProto)
protoFileToDesc = make(map[string]*descriptorpb.FileDescriptorProto)
protoFileToDescMu sync.RWMutex
)

func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto {
func unnestDesc(mdescs []*descriptorpb.DescriptorProto, indices []int) *descriptorpb.DescriptorProto {
mdesc := mdescs[indices[0]]
for _, index := range indices[1:] {
mdesc = mdesc.NestedType[index]
}
return mdesc
}

// Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// Invoking descriptorpb.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// for every single message, thus the need for a hand-rolled custom version that's performant and cacheable.
func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) {
func extractFileDescMessageDesc(desc descriptorIface) (*descriptorpb.FileDescriptorProto, *descriptorpb.DescriptorProto, error) {
gzippedPb, indices := desc.Descriptor()

protoFileToDescMu.RLock()
Expand All @@ -365,8 +364,8 @@ func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescripto
return nil, nil, err
}

fdesc := new(descriptor.FileDescriptorProto)
if err := proto.Unmarshal(protoBlob, fdesc); err != nil {
fdesc := new(descriptorpb.FileDescriptorProto)
if err := protov2.Unmarshal(protoBlob, fdesc); err != nil {
return nil, nil, err
}

Expand All @@ -380,8 +379,8 @@ func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescripto
}

type descriptorMatch struct {
cache map[int32]*descriptor.FieldDescriptorProto
desc *descriptor.DescriptorProto
cache map[int32]*descriptorpb.FieldDescriptorProto
desc *descriptorpb.DescriptorProto
}

var (
Expand All @@ -390,7 +389,7 @@ var (
)

// getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors.
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) {
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptorpb.FieldDescriptorProto, *descriptorpb.DescriptorProto, error) {
key := reflect.ValueOf(msg).Type()

descprotoCacheMu.RLock()
Expand All @@ -407,7 +406,7 @@ func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*desc
return nil, nil, err
}

tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto)
tagNumToTypeIndex := make(map[int32]*descriptorpb.FieldDescriptorProto)
for _, field := range md.Field {
tagNumToTypeIndex[field.GetNumber()] = field
}
Expand Down Expand Up @@ -441,3 +440,70 @@ func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) {
}
return reflect.New(mt.Elem()).Interface().(proto.Message), nil
}

func findWireTypeFromFieldDescriptorProtoType(fieldType descriptorpb.FieldDescriptorProto_Type) protowire.Type {
switch fieldType {
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
return 1
case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
return 5
case descriptorpb.FieldDescriptorProto_TYPE_INT64:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_UINT64:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_INT32:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_UINT32:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
return 1
case descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
return 5
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_STRING:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_GROUP:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
return 5
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
return 1
case descriptorpb.FieldDescriptorProto_TYPE_SINT32:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_SINT64:
return 0
}
panic("unreachable")
}

func isScalar(field *descriptorpb.FieldDescriptorProto) bool {
if field.Type == nil {
return false
}
switch *field.Type {
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
descriptorpb.FieldDescriptorProto_TYPE_INT64,
descriptorpb.FieldDescriptorProto_TYPE_UINT64,
descriptorpb.FieldDescriptorProto_TYPE_INT32,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
descriptorpb.FieldDescriptorProto_TYPE_BOOL,
descriptorpb.FieldDescriptorProto_TYPE_UINT32,
descriptorpb.FieldDescriptorProto_TYPE_ENUM,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
descriptorpb.FieldDescriptorProto_TYPE_SINT32,
descriptorpb.FieldDescriptorProto_TYPE_SINT64:
return true
default:
return false
}
}

0 comments on commit 8ada39e

Please sign in to comment.