Skip to content

Commit

Permalink
Properly support the range of uint64 and allow big int to unwrap into…
Browse files Browse the repository at this point in the history
… smaller integer types
  • Loading branch information
nolag committed Sep 27, 2024
1 parent 0784a13 commit 38bf583
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
17 changes: 17 additions & 0 deletions pkg/values/big_int.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package values
import (
"errors"
"fmt"
"math"
"math/big"
"reflect"

Expand Down Expand Up @@ -41,6 +42,14 @@ func (b *BigInt) UnwrapTo(to any) error {
return errors.New("cannot unwrap to nil pointer")
}
*tb = *b.Underlying
case *uint64:
if tb == nil {
return errors.New("cannot unwrap to nil pointer")
}
if b.Underlying.Cmp(new(big.Int).SetUint64(math.MaxUint64)) > 0 {
return errors.New("big.Int value is larger than uint64")
}
*tb = b.Underlying.Uint64()
case *any:
if tb == nil {
return errors.New("cannot unwrap to nil pointer")
Expand All @@ -52,6 +61,14 @@ func (b *BigInt) UnwrapTo(to any) error {
rto := reflect.ValueOf(to)
if rto.CanConvert(reflect.TypeOf(new(big.Int))) {
return b.UnwrapTo(rto.Convert(reflect.TypeOf(new(big.Int))).Interface())
} else if rto.CanConvert(reflect.TypeOf(new(uint64))) {
return b.UnwrapTo(rto.Convert(reflect.TypeOf(new(uint64))).Interface())
} else if rto.CanInt() || rto.CanUint() {
if b.Underlying.Cmp(big.NewInt(math.MaxInt64)) > 0 {
return fmt.Errorf("big.Int value is larger than int64")
}

return NewInt64(b.Underlying.Int64()).UnwrapTo(to)
}
return fmt.Errorf("cannot unwrap to value of type: %T", to)
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/values/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package values
import (
"errors"
"fmt"
"math"
"math/big"
"reflect"

Expand Down Expand Up @@ -57,6 +58,9 @@ func Wrap(v any) (Value, error) {
case int:
return NewInt64(int64(tv)), nil
case uint64:
if tv > math.MaxInt64 {
return NewBigInt(new(big.Int).SetUint64(tv)), nil
}
return NewInt64(int64(tv)), nil
case uint32:
return NewInt64(int64(tv)), nil
Expand Down Expand Up @@ -141,7 +145,9 @@ func Wrap(v any) (Value, error) {

case reflect.Bool:
return Wrap(val.Convert(reflect.TypeOf(true)).Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
case reflect.Uint64:
return Wrap(val.Convert(reflect.TypeOf(uint64(0))).Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return Wrap(val.Convert(reflect.TypeOf(int64(0))).Interface())
case reflect.Float32, reflect.Float64:
return Wrap(val.Convert(reflect.TypeOf(float64(0))).Interface())
Expand Down
37 changes: 37 additions & 0 deletions pkg/values/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,34 @@ func Test_IntTypes(t *testing.T) {
{name: "uint16", test: func(tt *testing.T) { wrappableTest[int64, uint16](tt, anyValue) }},
{name: "uint8", test: func(tt *testing.T) { wrappableTest[int64, uint8](tt, anyValue) }},
{name: "uint", test: func(tt *testing.T) { wrappableTest[int64, uint](tt, anyValue) }},
{name: "uint64 small enough for int64", test: func(tt *testing.T) {
u64, err := Wrap(uint64(math.MaxInt64))
require.NoError(tt, err)

expected, err := Wrap(int64(math.MaxInt64))
require.NoError(tt, err)

assert.Equal(tt, expected, u64)

unwrapped := uint64(0)
err = u64.UnwrapTo(&unwrapped)
require.NoError(tt, err)
assert.Equal(tt, uint64(math.MaxInt64), unwrapped)
}},
{name: "uint64 too large for int64", test: func(tt *testing.T) {
u64, err := Wrap(uint64(math.MaxInt64 + 1))
require.NoError(tt, err)

expected, err := Wrap(new(big.Int).SetUint64(math.MaxInt64 + 1))
require.NoError(tt, err)

assert.Equal(tt, expected, u64)

unwrapped := uint64(0)
err = u64.UnwrapTo(&unwrapped)
require.NoError(tt, err)
assert.Equal(tt, uint64(math.MaxInt64+1), unwrapped)
}},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -451,6 +479,7 @@ type aliasByte uint8
type decimalAlias decimal.Decimal
type bigIntAlias big.Int
type bigIntPtrAlias *big.Int
type aliasUint64 uint64

func Test_Aliases(t *testing.T) {
testCases := []struct {
Expand Down Expand Up @@ -481,6 +510,14 @@ func Test_Aliases(t *testing.T) {
name: "integer",
test: func(tt *testing.T) { wrappableTest[int, aliasInt](tt, 1) },
},
{
name: "uint64 fits in int64",
test: func(tt *testing.T) { wrappableTest[uint64, aliasUint64](tt, uint64(math.MaxInt64)) },
},
{
name: "uint64 too large for int64",
test: func(tt *testing.T) { wrappableTest[uint64, aliasUint64](tt, uint64(math.MaxInt64+1)) },
},
{
name: "map",
test: func(tt *testing.T) { wrappableTest[map[string]any, aliasMap](tt, map[string]any{"hello": "world"}) },
Expand Down

0 comments on commit 38bf583

Please sign in to comment.