diff --git a/extra/bunbig/int.go b/extra/bunbig/int.go index 88f109ede..a0d22644a 100644 --- a/extra/bunbig/int.go +++ b/extra/bunbig/int.go @@ -18,15 +18,31 @@ func newBigint(x *big.Int) *Int { return (*Int)(x) } -// same as NewBigint() +// FromMathBig is same as NewBigint() func FromMathBig(x *big.Int) *Int { return (*Int)(x) } +func (i *Int) ToMathBig() *big.Int { + return (*big.Int)(i) +} + func FromInt64(x int64) *Int { return FromMathBig(big.NewInt(x)) } +func (i *Int) ToUInt64() uint64 { + return i.ToMathBig().Uint64() +} + +func FromUInt64(x uint64) *Int { + return FromMathBig(new(big.Int).SetUint64(x)) +} + +func (i *Int) ToInt64() int64 { + return i.ToMathBig().Int64() +} + func (i *Int) FromString(x string) (*Int, error) { if x == "" { return FromInt64(0), nil @@ -41,69 +57,73 @@ func (i *Int) FromString(x string) (*Int, error) { return newBigint(b), nil } -func (b *Int) Value() (driver.Value, error) { - return (*big.Int)(b).String(), nil +func (i *Int) String() string { + return i.ToMathBig().String() } -func (b *Int) Scan(value interface{}) error { +func (i *Int) Value() (driver.Value, error) { + return (*big.Int)(i).String(), nil +} - var i sql.NullString +func (i *Int) Scan(value interface{}) error { + var x sql.NullString - if err := i.Scan(value); err != nil { + if err := x.Scan(value); err != nil { return err } - if _, ok := (*big.Int)(b).SetString(i.String, 10); ok { + if _, ok := (*big.Int)(i).SetString(x.String, 10); ok { return nil } - return fmt.Errorf("Error converting type %T into Bigint", value) + return fmt.Errorf("error converting type %T into Int", value) } -func (b *Int) ToMathBig() *big.Int { - return (*big.Int)(b) +func (i *Int) MarshalJSON() ([]byte, error) { + return []byte(i.String()), nil } -func (b *Int) Sub(x *Int) *Int { - return (*Int)(big.NewInt(0).Sub(b.ToMathBig(), x.ToMathBig())) -} - -func (b *Int) Add(x *Int) *Int { - return (*Int)(big.NewInt(0).Add(b.ToMathBig(), x.ToMathBig())) -} - -func (b *Int) Mul(x *Int) *Int { - return (*Int)(big.NewInt(0).Mul(b.ToMathBig(), x.ToMathBig())) +func (i *Int) UnmarshalJSON(p []byte) error { + if string(p) == "null" { + return nil + } + var z big.Int + _, ok := z.SetString(string(p), 10) + if !ok { + return fmt.Errorf("not a valid big integer: %s", p) + } + *i = (Int)(z) + return nil } -func (b *Int) Div(x *Int) *Int { - return (*Int)(big.NewInt(0).Div(b.ToMathBig(), x.ToMathBig())) +func (i *Int) Sub(x *Int) *Int { + return (*Int)(big.NewInt(0).Sub(i.ToMathBig(), x.ToMathBig())) } -func (b *Int) Neg() *Int { - return (*Int)(big.NewInt(0).Neg(b.ToMathBig())) +func (i *Int) Add(x *Int) *Int { + return (*Int)(big.NewInt(0).Add(i.ToMathBig(), x.ToMathBig())) } -func (b *Int) ToUInt64() uint64 { - return b.ToMathBig().Uint64() +func (i *Int) Mul(x *Int) *Int { + return (*Int)(big.NewInt(0).Mul(i.ToMathBig(), x.ToMathBig())) } -func (b *Int) ToInt64() int64 { - return b.ToMathBig().Int64() +func (i *Int) Div(x *Int) *Int { + return (*Int)(big.NewInt(0).Div(i.ToMathBig(), x.ToMathBig())) } -func (b *Int) String() string { - return b.ToMathBig().String() +func (i *Int) Neg() *Int { + return (*Int)(big.NewInt(0).Neg(i.ToMathBig())) } -func (b *Int) Abs() *Int { - return (*Int)(new(big.Int).Abs(b.ToMathBig())) +func (i *Int) Abs() *Int { + return (*Int)(new(big.Int).Abs(i.ToMathBig())) } var _ yaml.Unmarshaler = (*Int)(nil) // @todo , this part needs to be fixed -func (b *Int) UnmarshalYAML(value *yaml.Node) error { +func (i *Int) UnmarshalYAML(value *yaml.Node) error { var str string if err := value.Decode(&str); err != nil { return err @@ -115,8 +135,8 @@ func (b *Int) UnmarshalYAML(value *yaml.Node) error { return nil } -func (b *Int) Cmp(target *Int) Cmp { - return &cmpInt{r: b.ToMathBig().Cmp(target.ToMathBig())} +func (i *Int) Cmp(target *Int) Cmp { + return &cmpInt{r: i.ToMathBig().Cmp(target.ToMathBig())} } func (c *cmpInt) Eq() bool { diff --git a/extra/bunbig/int_test.go b/extra/bunbig/int_test.go index e3aa4ecdf..1d391c6fc 100644 --- a/extra/bunbig/int_test.go +++ b/extra/bunbig/int_test.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/uptrace/bun/extra/bunbig" "gopkg.in/yaml.v3" + + "github.com/uptrace/bun/extra/bunbig" ) func TestInt(t *testing.T) { @@ -20,6 +21,7 @@ func TestInt(t *testing.T) { // 100 * 200 = 20000 assert.Equal(t, bunbig.FromMathBig(big.NewInt(20000)), x.Mul(y)) }) + t.Run("add", func(t *testing.T) { x := bunbig.FromMathBig(a) y := bunbig.FromMathBig(b) @@ -55,25 +57,33 @@ func TestInt(t *testing.T) { x := bunbig.FromMathBig(a) assert.Equal(t, uint64(100), x.ToUInt64()) }) + t.Run("toString", func(t *testing.T) { x := bunbig.FromMathBig(a) assert.Equal(t, "100", x.String()) }) + t.Run("fromString", func(t *testing.T) { x, err := bunbig.NewInt().FromString("100") assert.Nil(t, err) assert.Equal(t, "100", x.String()) }) + t.Run("fromInt64", func(t *testing.T) { x := bunbig.FromInt64(100000000) assert.Equal(t, int64(100000000), x.ToInt64()) }) + t.Run("fromUInt64", func(t *testing.T) { + x := bunbig.FromUInt64(100000000) + assert.Equal(t, int64(100000000), x.ToInt64()) + }) + t.Run("Abs", func(t *testing.T) { x := bunbig.FromMathBig(a) - assert.Equal(t, x.Neg().Abs(), x) }) + t.Run("compare: ", func(t *testing.T) { x := bunbig.FromMathBig(a) // 100 y := bunbig.FromMathBig(b) // 200 @@ -98,12 +108,23 @@ func TestInt(t *testing.T) { }) t.Run("empty string ", func(t *testing.T) { - x, err := bunbig.NewInt().FromString("") assert.Nil(t, err) assert.Equal(t, x.ToInt64(), int64(0)) }) + t.Run("json", func(t *testing.T) { + i := bunbig.FromInt64(1337) + + r, err := i.MarshalJSON() + assert.Nil(t, err) + assert.Equal(t, "1337", string(r)) + + got := new(bunbig.Int) + err = got.UnmarshalJSON(r) + assert.Nil(t, err) + assert.Equal(t, uint64(1337), got.ToUInt64()) + }) } func TestFloat(t *testing.T) {