Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rlp: optimize big.Int decoding for size <= 32 bytes #22927

Merged
merged 3 commits into from
May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions rlp/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,51 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error {
}

func decodeBigInt(s *Stream, val reflect.Value) error {
b, err := s.Bytes()
if err != nil {
var buffer []byte
kind, size, err := s.Kind()
switch {
case err != nil:
return wrapStreamError(err, val.Type())
case kind == List:
return wrapStreamError(ErrExpectedString, val.Type())
case kind == Byte:
buffer = s.uintbuf[:1]
buffer[0] = s.byteval
s.kind = -1 // re-arm Kind
case size == 0:
// Avoid zero-length read.
s.kind = -1
case size <= uint64(len(s.uintbuf)):
// For integers smaller than s.uintbuf, allocating a buffer
// can be avoided.
buffer = s.uintbuf[:size]
if err := s.readFull(buffer); err != nil {
return wrapStreamError(err, val.Type())
}
// Reject inputs where single byte encoding should have been used.
if size == 1 && buffer[0] < 128 {
return wrapStreamError(ErrCanonSize, val.Type())
}
default:
// For large integers, a temporary buffer is needed.
buffer = make([]byte, size)
if err := s.readFull(buffer); err != nil {
return wrapStreamError(err, val.Type())
}
}

// Reject leading zero bytes.
if len(buffer) > 0 && buffer[0] == 0 {
return wrapStreamError(ErrCanonInt, val.Type())
}

// Set the integer bytes.
i := val.Interface().(*big.Int)
if i == nil {
i = new(big.Int)
val.Set(reflect.ValueOf(i))
}
// Reject leading zero bytes.
if len(b) > 0 && b[0] == 0 {
return wrapStreamError(ErrCanonInt, val.Type())
}
i.SetBytes(b)
i.SetBytes(buffer)
return nil
}

Expand Down Expand Up @@ -563,7 +594,7 @@ type Stream struct {
size uint64 // size of value ahead
kinderr error // error from last readKind
stack []uint64 // list sizes
uintbuf [8]byte // auxiliary buffer for integer decoding
uintbuf [32]byte // auxiliary buffer for integer decoding
kind Kind // kind of value ahead
byteval byte // value of single byte in type tag
limited bool // true if input limit is in effect
Expand Down Expand Up @@ -817,7 +848,7 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
s.kind = -1
s.kinderr = nil
s.byteval = 0
s.uintbuf = [8]byte{}
s.uintbuf = [32]byte{}
}

// Kind returns the kind and size of the next value in the
Expand Down Expand Up @@ -927,17 +958,20 @@ func (s *Stream) readUint(size byte) (uint64, error) {
b, err := s.readByte()
return uint64(b), err
default:
buffer := s.uintbuf[:8]
for i := range buffer {
buffer[i] = 0
}
start := int(8 - size)
s.uintbuf = [8]byte{}
if err := s.readFull(s.uintbuf[start:]); err != nil {
if err := s.readFull(buffer[start:]); err != nil {
return 0, err
}
if s.uintbuf[start] == 0 {
if buffer[start] == 0 {
// Note: readUint is also used to decode integer values.
// The error needs to be adjusted to become ErrCanonInt in this case.
return 0, ErrCanonSize
}
return binary.BigEndian.Uint64(s.uintbuf[:]), nil
return binary.BigEndian.Uint64(buffer[:]), nil
}
}

Expand Down
22 changes: 19 additions & 3 deletions rlp/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ type recstruct struct {
Child *recstruct `rlp:"nil"`
}

type bigIntStruct struct {
I *big.Int
B string
}

type invalidNilTag struct {
X []byte `rlp:"nil"`
}
Expand Down Expand Up @@ -405,10 +410,11 @@ type ignoredField struct {
}

var (
veryBigInt = big.NewInt(0).Add(
veryBigInt = new(big.Int).Add(
big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
big.NewInt(0xFFFF),
)
veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil)
)

var decodeTests = []decodeTest{
Expand Down Expand Up @@ -479,12 +485,15 @@ var decodeTests = []decodeTest{
{input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"},

// big ints
{input: "80", ptr: new(*big.Int), value: big.NewInt(0)},
{input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
{input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt},
{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
{input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"},
{input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
{input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"},
{input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
{input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
{input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"},

// structs
{
Expand All @@ -497,6 +506,13 @@ var decodeTests = []decodeTest{
ptr: new(recstruct),
value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
},
{
// This checks that empty big.Int works correctly in struct context. It's easy to
// miss the update of s.kind for this case, so it needs its own test.
input: "C58083343434",
ptr: new(bigIntStruct),
value: bigIntStruct{new(big.Int), "444"},
},

// struct errors
{
Expand Down
8 changes: 8 additions & 0 deletions rlp/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ var encTests = []encTest{
val: big.NewInt(0).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")),
output: "A1010000000000000000000000000000000000000000000000000000000000000000",
},
{
val: veryBigInt,
output: "89FFFFFFFFFFFFFFFFFF",
},
{
val: veryVeryBigInt,
output: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001",
},

// non-pointer big.Int
{val: *big.NewInt(0), output: "80"},
Expand Down