From 76e56952b611a270e356e60996a7b90a9a542ecc Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Mon, 1 Aug 2022 23:45:07 +0300 Subject: [PATCH] chore: various fixes and small refactorings - Add benchmark for slice of ints. - Rename several variables. Signed-off-by: Dmitriy Matrenichev --- benchmarks_test.go | 64 ++++++++++++++++++++ unmarshal.go | 147 +++++++++++++++++++++------------------------ 2 files changed, 132 insertions(+), 79 deletions(-) diff --git a/benchmarks_test.go b/benchmarks_test.go index e5311b1..3b81d69 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -75,3 +75,67 @@ func BenchmarkCustom(b *testing.B) { } } } + +func BenchmarkSlice(b *testing.B) { + type structWithSlice struct { + Field []int `protobuf:"1"` + } + + type structType struct { + Field structWithSlice `protobuf:"1"` + } + + o := structType{ + Field: structWithSlice{ + Field: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 1500, 1600}, + }, + } + + encoded, err := protoenc.Marshal(&o) + require.NoError(b, err) + + b.ResetTimer() + b.ReportAllocs() + + target := &structType{} + for i := 0; i < b.N; i++ { + *target = structType{} + + err := protoenc.Unmarshal(encoded, target) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkString(b *testing.B) { + type structWithString struct { + Field string `protobuf:"1"` + } + + type structType struct { + Field structWithString `protobuf:"1"` + } + + o := structType{ + Field: structWithString{ + Field: "stuff to benchmark", + }, + } + + encoded, err := protoenc.Marshal(&o) + require.NoError(b, err) + + b.ResetTimer() + b.ReportAllocs() + + target := &structType{} + for i := 0; i < b.N; i++ { + *target = structType{} + + err := protoenc.Unmarshal(encoded, target) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/unmarshal.go b/unmarshal.go index fb5683d..d86c1e1 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -94,10 +94,6 @@ func (u *unmarshaller) unmarshalStruct(buf []byte, structVal reflect.Value) erro field = initStructField(structVal, structFields[fieldIndex]) } - // For more debugging output, uncomment the following three lines. - // if fieldi < len(fields){ - // fmt.Printf("Decoding FieldName %+v\n", fields[fieldi].Field) - // } // Decode the field's value rem, err := u.decodeValue(wiretype, buf, field) if err != nil { @@ -171,14 +167,14 @@ func findField(fields []FieldData, fieldnum protowire.Number) int { func (u *unmarshaller) decodeValue(wiretype protowire.Type, buf []byte, dst reflect.Value) ([]byte, error) { var ( // Break out the value from the buffer based on the wire type - v uint64 - n int - vb []byte + decodedValue uint64 + n int + decodedBytes []byte ) switch wiretype { //nolint:exhaustive case protowire.VarintType: - v, n = protowire.ConsumeVarint(buf) + decodedValue, n = protowire.ConsumeVarint(buf) if n <= 0 { return nil, errors.New("bad protobuf varint value") } @@ -193,7 +189,7 @@ func (u *unmarshaller) decodeValue(wiretype protowire.Type, buf []byte, dst refl return nil, errors.New("bad protobuf 32-bit value") } - v = uint64(res) + decodedValue = uint64(res) buf = buf[n:] case protowire.Fixed64Type: @@ -204,23 +200,23 @@ func (u *unmarshaller) decodeValue(wiretype protowire.Type, buf []byte, dst refl return nil, errors.New("bad protobuf 64-bit value") } - v = res + decodedValue = res buf = buf[n:] case protowire.BytesType: - vb, n = protowire.ConsumeBytes(buf) + decodedBytes, n = protowire.ConsumeBytes(buf) if n <= 0 { return nil, errors.New("bad protobuf length-delimited value") } - vb = vb[:len(vb):len(vb)] + decodedBytes = decodedBytes[:len(decodedBytes):len(decodedBytes)] buf = buf[n:] default: return nil, errors.New("unknown protobuf wire-type") } - if err := u.putInto(dst, wiretype, v, vb); err != nil { + if err := u.putInto(dst, wiretype, decodedValue, decodedBytes); err != nil { return nil, err } @@ -228,8 +224,8 @@ func (u *unmarshaller) decodeValue(wiretype protowire.Type, buf []byte, dst refl } //nolint:gocognit,gocyclo,cyclop -func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uint64, vb []byte) error { - // Value is not settable (invalid reflect.Value, private +func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, decodedValue uint64, decodedBytes []byte) error { + // Value is not settable (invalid reflect.Value, private) if !dst.CanSet() { return nil } @@ -241,14 +237,14 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin return fmt.Errorf("bad wiretype for time.Time: %v", wiretype) } - var t timestamppb.Timestamp + var result timestamppb.Timestamp - err := proto.Unmarshal(vb, &t) + err := proto.Unmarshal(decodedBytes, &result) if err != nil { return err } - dst.Set(reflect.ValueOf(t.AsTime())) + dst.Set(reflect.ValueOf(result.AsTime())) return nil case typeDuration: @@ -256,14 +252,14 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin return fmt.Errorf("bad wiretype for time.Duration: %v", wiretype) } - var d durationpb.Duration + var result durationpb.Duration - err := proto.Unmarshal(vb, &d) + err := proto.Unmarshal(decodedBytes, &result) if err != nil { return err } - dst.Set(reflect.ValueOf(d.AsDuration())) + dst.Set(reflect.ValueOf(result.AsDuration())) return nil } @@ -274,11 +270,11 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin return fmt.Errorf("bad wiretype for bool: %v", wiretype) } - if v > 1 { + if decodedValue > 1 { return errors.New("invalid bool value") } - dst.SetBool(protowire.DecodeBool(v)) + dst.SetBool(protowire.DecodeBool(decodedValue)) case reflect.Int, reflect.Int32, reflect.Int64: // Signed integers may be encoded either zigzag-varint or fixed @@ -287,9 +283,9 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin return errors.New("detected a 32bit machine, please use either int64 or int32") } - sv, err := decodeSignedInt(wiretype, v) + sv, err := decodeSignedInt(wiretype, decodedValue) if err != nil { - fmt.Println("Error Reflect.Int for v=", v, "wiretype=", wiretype, "for Value=", dst.Type().Name()) + fmt.Println("Error Reflect.Int for decodedValue=", decodedValue, "wiretype=", wiretype, "for Value=", dst.Type().Name()) return err } @@ -304,11 +300,11 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin switch wiretype { //nolint:exhaustive case protowire.VarintType: - dst.SetUint(v) + dst.SetUint(decodedValue) case protowire.Fixed32Type: - dst.SetUint(uint64(uint32(v))) + dst.SetUint(uint64(uint32(decodedValue))) case protowire.Fixed64Type: - dst.SetUint(v) + dst.SetUint(decodedValue) default: return errors.New("bad wiretype for uint") } @@ -318,21 +314,21 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin return errors.New("bad wiretype for float32") } - dst.SetFloat(float64(math.Float32frombits(uint32(v)))) + dst.SetFloat(float64(math.Float32frombits(uint32(decodedValue)))) case reflect.Float64: if wiretype != protowire.Fixed64Type { return errors.New("bad wiretype for float64") } - dst.SetFloat(math.Float64frombits(v)) + dst.SetFloat(math.Float64frombits(decodedValue)) case reflect.String: if wiretype != protowire.BytesType { return errors.New("bad wiretype for string") } - dst.SetString(string(vb)) + dst.SetString(string(decodedBytes)) case reflect.Ptr: if dst.IsNil() { @@ -342,18 +338,18 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin } } - return u.putInto(dst.Elem(), wiretype, v, vb) + return u.putInto(dst.Elem(), wiretype, decodedValue, decodedBytes) case reflect.Struct: if enc, ok := dst.Addr().Interface().(encoding.BinaryUnmarshaler); ok { - return enc.UnmarshalBinary(vb) + return enc.UnmarshalBinary(decodedBytes) } if wiretype != protowire.BytesType { return errors.New("bad wiretype for embedded message") } - return u.unmarshalStruct(vb, dst) + return u.unmarshalStruct(decodedBytes, dst) case reflect.Slice, reflect.Array: // Repeated field or byte-slice @@ -361,7 +357,7 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin return errors.New("bad wiretype for repeated field") } - return u.slice(dst, vb) + return u.slice(dst, decodedBytes) case reflect.Map: if wiretype != protowire.BytesType { return errors.New("bad wiretype for repeated field") @@ -371,10 +367,8 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin dst.Set(reflect.MakeMap(dst.Type())) } - return u.mapEntry(dst, vb) + return u.mapEntry(dst, decodedBytes) case reflect.Interface: - data := vb - // TODO: find a way to handle nil interfaces if dst.IsNil() { return errors.New("nil interface fields are not supported") @@ -386,13 +380,11 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, v uin return errors.New("bad wiretype for bytes") } - return enc.UnmarshalBinary(data) + return enc.UnmarshalBinary(decodedBytes) } // Decode into the object the interface points to. - // XXX perhaps better ONLY to support self-decoding - // for interface fields? - return Unmarshal(vb, dst.Interface()) + return Unmarshal(decodedBytes, dst.Interface()) default: panic("unsupported value kind " + dst.Kind().String()) @@ -441,12 +433,11 @@ func instantiate(dst reflect.Value) error { return nil } -func (u *unmarshaller) slice(slice reflect.Value, vb []byte) error { +func (u *unmarshaller) slice(dst reflect.Value, decodedBytes []byte) error { // Find the element type, and create a temporary instance of it. - eltype := slice.Type().Elem() - val := reflect.New(eltype).Elem() + elemType := dst.Type().Elem() - ok, err := tryDecodeUnpackedByteSlice(slice, eltype, vb) + ok, err := tryDecodeUnpackedByteSlice(dst, elemType, decodedBytes) if err != nil { return err } @@ -455,50 +446,48 @@ func (u *unmarshaller) slice(slice reflect.Value, vb []byte) error { return nil } - wiretype, err := getWiretype(eltype) + wiretype, err := getWiretypeFor(elemType) if err != nil { return err } + elem := reflect.New(elemType).Elem() + if wiretype < 0 { // Other unpacked repeated types - // Just unpack and append one value from vb. - if err := u.putInto(val, protowire.BytesType, 0, vb); err != nil { + // Just unpack and append one value from decodedBytes. + if err := u.putInto(elem, protowire.BytesType, 0, decodedBytes); err != nil { return err } - if slice.Kind() != reflect.Slice { - return errors.New("append to non-slice") - } - - slice.Set(reflect.Append(slice, val)) + dst.Set(reflect.Append(dst, elem)) return nil } - // Decode packed values from the buffer and append them to the slice. - for len(vb) > 0 { - rem, err := u.decodeValue(wiretype, vb, val) + // Decode packed values from the buffer and append them to the dst. + for len(decodedBytes) > 0 { + rem, err := u.decodeValue(wiretype, decodedBytes, elem) if err != nil { return err } - slice.Set(reflect.Append(slice, val)) + dst.Set(reflect.Append(dst, elem)) - vb = rem + decodedBytes = rem } return nil } -func getWiretype(eltype reflect.Type) (protowire.Type, error) { - switch eltype.Kind() { //nolint:exhaustive +func getWiretypeFor(elemType reflect.Type) (protowire.Type, error) { + switch elemType.Kind() { //nolint:exhaustive case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint32, reflect.Uint64, reflect.Uint: - if (eltype.Kind() == reflect.Int || eltype.Kind() == reflect.Uint) && eltype.Size() < 8 { + if (elemType.Kind() == reflect.Int || elemType.Kind() == reflect.Uint) && elemType.Size() < 8 { return 0, errors.New("detected a 32bit machine, please either use (u)int64 or (u)int32") } - switch eltype { + switch elemType { case typeFixedS32: return protowire.Fixed32Type, nil case typeFixedS64: @@ -523,40 +512,40 @@ func getWiretype(eltype reflect.Type) (protowire.Type, error) { } } -func tryDecodeUnpackedByteSlice(slice reflect.Value, eltype reflect.Type, vb []byte) (bool, error) { - if eltype.Kind() != reflect.Uint8 { +func tryDecodeUnpackedByteSlice(dst reflect.Value, elemType reflect.Type, decodedBytes []byte) (bool, error) { + if elemType.Kind() != reflect.Uint8 { return false, nil } - if slice.Kind() == reflect.Array { - if slice.Len() != len(vb) { + if dst.Kind() == reflect.Array { + if dst.Len() != len(decodedBytes) { return false, errors.New("array length and buffer length differ") } - for i := 0; i < slice.Len(); i++ { + for i := 0; i < dst.Len(); i++ { // no SetByte method in reflect so has to pass down by uint64 - slice.Index(i).SetUint(uint64(vb[i])) + dst.Index(i).SetUint(uint64(decodedBytes[i])) } } else { - slice.SetBytes(vb) + dst.SetBytes(decodedBytes) } return true, nil } -func (u *unmarshaller) mapEntry(slval reflect.Value, vb []byte) error { - mKey := reflect.New(slval.Type().Key()).Elem() - mVal := reflect.New(slval.Type().Elem()).Elem() +func (u *unmarshaller) mapEntry(dstEntry reflect.Value, decodedBytes []byte) error { + entryKey := reflect.New(dstEntry.Type().Key()).Elem() + entryVal := reflect.New(dstEntry.Type().Elem()).Elem() - _, wiretype, n := protowire.ConsumeTag(vb) + _, wiretype, n := protowire.ConsumeTag(decodedBytes) if n <= 0 { return errors.New("bad protobuf field key") } - buf := vb[n:] + buf := decodedBytes[n:] var err error - buf, err = u.decodeValue(wiretype, buf, mKey) + buf, err = u.decodeValue(wiretype, buf, entryKey) if err != nil { return err @@ -569,20 +558,20 @@ func (u *unmarshaller) mapEntry(slval reflect.Value, vb []byte) error { } buf = buf[n:] - buf, err = u.decodeValue(wiretype, buf, mVal) + buf, err = u.decodeValue(wiretype, buf, entryVal) if err != nil { return err } } - if !mKey.IsValid() || !mVal.IsValid() { + if !entryKey.IsValid() || !entryVal.IsValid() { // We did not decode the key or the value in the map entry. // Either way, it's an invalid map entry. return errors.New("proto: bad map data: missing key/val") } - slval.SetMapIndex(mKey, mVal) + dstEntry.SetMapIndex(entryKey, entryVal) return nil }