From e85c5956fdd94d4c07741f52abd93f1db68518fa Mon Sep 17 00:00:00 2001 From: aBear Date: Wed, 16 Oct 2024 23:35:12 +0200 Subject: [PATCH] hardened GweiFromWei to avoid integer truncation --- mod/primitives/pkg/math/u64.go | 13 ++++- mod/primitives/pkg/math/u64_test.go | 87 +++++++++++++++++++++-------- 2 files changed, 76 insertions(+), 24 deletions(-) diff --git a/mod/primitives/pkg/math/u64.go b/mod/primitives/pkg/math/u64.go index cad6b74ea8..08f0ce17e0 100644 --- a/mod/primitives/pkg/math/u64.go +++ b/mod/primitives/pkg/math/u64.go @@ -24,6 +24,7 @@ import ( "math/big" "strconv" + "github.com/berachain/beacon-kit/mod/errors" "github.com/berachain/beacon-kit/mod/primitives/pkg/encoding/hex" "github.com/berachain/beacon-kit/mod/primitives/pkg/math/log" "github.com/berachain/beacon-kit/mod/primitives/pkg/math/pow" @@ -136,11 +137,19 @@ func (u U64) ILog2Floor() uint8 { // ---------------------------- Gwei Methods ---------------------------- +var ErrGweiOverflow = errors.New("gwei from big.Int overflows") + // GweiFromWei returns the value of Wei in Gwei. -func GweiFromWei(i *big.Int) Gwei { +func GweiFromWei(i *big.Int) (Gwei, error) { intToGwei := big.NewInt(0).SetUint64(GweiPerWei) i.Div(i, intToGwei) - return Gwei(i.Uint64()) + if !i.IsUint64() { + // a Gwei amount >= (2**64) * (10**9) or negative would not + // be representable as uint64. This should not happen but + // we still guard against a serialization bug or other mishap. + return 0, ErrGweiOverflow + } + return Gwei(i.Uint64()), nil } // ToWei converts a value from Gwei to Wei. diff --git a/mod/primitives/pkg/math/u64_test.go b/mod/primitives/pkg/math/u64_test.go index c9510377b7..06090ba2e8 100644 --- a/mod/primitives/pkg/math/u64_test.go +++ b/mod/primitives/pkg/math/u64_test.go @@ -328,38 +328,81 @@ func TestU64_PrevPowerOfTwo(t *testing.T) { func TestGweiFromWei(t *testing.T) { tests := []struct { - name string - input *big.Int - expected math.Gwei + name string + input func(t *testing.T) *big.Int + expectedErr error + expectedRes math.Gwei }{ { - name: "zero wei", - input: big.NewInt(0), - expected: math.Gwei(0), - }, - { - name: "one gwei", - input: big.NewInt(math.GweiPerWei), - expected: math.Gwei(1), - }, - { - name: "arbitrary wei", - input: big.NewInt(math.GweiPerWei * 123456789), - expected: math.Gwei(123456789), + name: "invalid negative gwei", + input: func(t *testing.T) *big.Int { + t.Helper() + b, _ := new(big.Int).SetString("-1", 10) + return b + }, + expectedErr: math.ErrGweiOverflow, + expectedRes: math.Gwei(0), + }, + { + name: "invalid huge gwei", + input: func(t *testing.T) *big.Int { + t.Helper() + b, _ := new(big.Int).SetString("18446744073709551616000000000", 10) + return b + }, + expectedErr: math.ErrGweiOverflow, + expectedRes: math.Gwei(0), + }, + { + name: "zero wei", + input: func(t *testing.T) *big.Int { + t.Helper() + return big.NewInt(0) + }, + expectedErr: nil, + expectedRes: math.Gwei(0), + }, + { + name: "one gwei", + input: func(t *testing.T) *big.Int { + t.Helper() + return big.NewInt(math.GweiPerWei) + }, + expectedErr: nil, + expectedRes: math.Gwei(1), + }, + { + name: "arbitrary wei", + input: func(t *testing.T) *big.Int { + t.Helper() + return big.NewInt(math.GweiPerWei * 123456789) + }, + expectedErr: nil, + expectedRes: math.Gwei(123456789), }, { name: "max uint64 wei", - input: new( - big.Int, - ).Mul(big.NewInt(math.GweiPerWei), new(big.Int).SetUint64(^uint64(0))), - expected: math.Gwei(1<<64 - 1), + input: func(t *testing.T) *big.Int { + t.Helper() + return new(big.Int).Mul( + big.NewInt(math.GweiPerWei), + new(big.Int).SetUint64(^uint64(0)), + ) + }, + expectedErr: nil, + expectedRes: math.Gwei(1<<64 - 1), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := math.GweiFromWei(tt.input) - require.Equal(t, tt.expected, result, "Test case: %s", tt.name) + result, err := math.GweiFromWei(tt.input(t)) + if tt.expectedErr != nil { + require.ErrorIs(t, err, tt.expectedErr) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedRes, result, "Test case: %s", tt.name) + } }) } }