Skip to content

Commit

Permalink
expression, types: fix decimal minus/round/multiple result (#7001)
Browse files Browse the repository at this point in the history
  • Loading branch information
lysu authored and zz-jason committed Jul 11, 2018
1 parent cc72254 commit 5f7fc80
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 45 deletions.
3 changes: 2 additions & 1 deletion expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package aggregation

import (
"github.com/cznic/mathutil"
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
Expand Down Expand Up @@ -84,7 +85,7 @@ func (af *avgFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum) {
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = to.Round(to, frac, types.ModeHalfEven)
err = to.Round(to, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
terror.Log(errors.Trace(err))
d.SetMysqlDecimal(to)
}
Expand Down
3 changes: 2 additions & 1 deletion expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tipb/go-tipb"
)
Expand Down Expand Up @@ -523,7 +524,7 @@ func (s *builtinArithmeticMultiplyDecimalSig) evalDecimal(row types.Row) (*types
}
c := &types.MyDecimal{}
err = types.DecimalMul(a, b, c)
if err != nil {
if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) {
return nil, true, errors.Trace(err)
}
return c, false, nil
Expand Down
50 changes: 26 additions & 24 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"strings"
"time"

"github.com/cznic/mathutil"
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -263,8 +264,10 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if mysql.HasUnsignedFlag(argFieldTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}

bf.tp.Flen = argFieldTp.Flen
bf.tp.Decimal = 0
bf.tp.Decimal = calculateDecimal4RoundAndTruncate(ctx, args, argTp)

var sig builtinFunc
if len(args) > 1 {
switch argTp {
Expand Down Expand Up @@ -292,6 +295,25 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
return sig, nil
}

// calculateDecimal4RoundAndTruncate calculates tp.decimals of round/truncate func.
func calculateDecimal4RoundAndTruncate(ctx sessionctx.Context, args []Expression, retType types.EvalType) int {
if retType == types.ETInt || len(args) <= 1 {
return 0
}
secondConst, secondIsConst := args[1].(*Constant)
if !secondIsConst {
return args[0].GetType().Decimal
}
argDec, isNull, err := secondConst.EvalInt(ctx, nil)
if err != nil || isNull || argDec < 0 {
return 0
}
if argDec > mysql.MaxDecimalScale {
return mysql.MaxDecimalScale
}
return int(argDec)
}

type builtinRoundRealSig struct {
baseBuiltinFunc
}
Expand Down Expand Up @@ -422,7 +444,7 @@ func (b *builtinRoundWithFracDecSig) evalDecimal(row types.Row) (*types.MyDecima
return nil, isNull, errors.Trace(err)
}
to := new(types.MyDecimal)
if err = val.Round(to, int(frac), types.ModeHalfEven); err != nil {
if err = val.Round(to, mathutil.Min(int(frac), b.tp.Decimal), types.ModeHalfEven); err != nil {
return nil, true, errors.Trace(err)
}
return to, false, nil
Expand Down Expand Up @@ -1695,22 +1717,6 @@ type truncateFunctionClass struct {
baseFunctionClass
}

// getDecimal returns the `Decimal` value of return type for function `TRUNCATE`.
func (c *truncateFunctionClass) getDecimal(ctx sessionctx.Context, arg Expression) int {
if constant, ok := arg.(*Constant); ok {
decimal, isNull, err := constant.EvalInt(ctx, nil)
if isNull || err != nil {
return 0
} else if decimal > 30 {
return 30
} else if decimal < 0 {
return 0
}
return int(decimal)
}
return 3
}

func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
Expand All @@ -1723,11 +1729,7 @@ func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expre

bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, argTp, types.ETInt)

if argTp == types.ETInt {
bf.tp.Decimal = 0
} else {
bf.tp.Decimal = c.getDecimal(bf.ctx, args[1])
}
bf.tp.Decimal = calculateDecimal4RoundAndTruncate(ctx, args, argTp)
bf.tp.Flen = args[0].GetType().Flen - args[0].GetType().Decimal + bf.tp.Decimal
bf.tp.Flag |= args[0].GetType().Flag

Expand Down Expand Up @@ -1768,7 +1770,7 @@ func (b *builtinTruncateDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal
}

result := new(types.MyDecimal)
if err := x.Round(result, int(d), types.ModeTruncate); err != nil {
if err := x.Round(result, mathutil.Min(int(d), b.getRetTp().Decimal), types.ModeTruncate); err != nil {
return nil, true, errors.Trace(err)
}
return result, false, nil
Expand Down
6 changes: 1 addition & 5 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -749,15 +749,11 @@ func (b *builtinUnaryMinusDecimalSig) Clone() builtinFunc {
}

func (b *builtinUnaryMinusDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal, bool, error) {
var dec *types.MyDecimal
dec, isNull, err := b.args[0].EvalDecimal(b.ctx, row)
if err != nil || isNull {
return dec, isNull, errors.Trace(err)
}

to := new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), dec, to)
return to, err != nil, errors.Trace(err)
return types.DecimalNeg(dec), false, nil
}

type builtinUnaryMinusRealSig struct {
Expand Down
4 changes: 3 additions & 1 deletion expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3372,7 +3372,7 @@ func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) {
return store, dom, errors.Trace(err)
}

func (s *testIntegrationSuite) TestTwoDecimalAssignTruncate(c *C) {
func (s *testIntegrationSuite) TestTwoDecimalTruncate(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer s.cleanEnv(c)
tk.MustExec("use test")
Expand All @@ -3383,4 +3383,6 @@ func (s *testIntegrationSuite) TestTwoDecimalAssignTruncate(c *C) {
tk.MustExec("update t1 set b = a")
res := tk.MustQuery("select a, b from t1")
res.Check(testkit.Rows("123.12345 123.1"))
res = tk.MustQuery("select 2.00000000000000000000000000000001 * 1.000000000000000000000000000000000000000000002")
res.Check(testkit.Rows("2.000000000000000000000000000000"))
}
6 changes: 3 additions & 3 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,9 @@ func (s *testInferTypeSuite) createTestCase4MathFuncs() []typeInferTestCase {

{"round(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"round(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
{"round(c_float_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 12, 0}, // Should be 17.
{"round(c_double_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, // Should be 17.
{"round(c_decimal )", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, // Should be 5.
{"round(c_float_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 12, 0}, // flen Should be 17.
{"round(c_double_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, // flen Should be 17.
{"round(c_decimal )", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, // flen Should be 5.
{"round(c_datetime )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"round(c_time_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"round(c_timestamp_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
Expand Down
23 changes: 15 additions & 8 deletions types/mydecimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,6 @@ func (d *MyDecimal) doMiniRightShift(shift, beg, end int) {
// RETURN VALUE
// eDecOK/eDecTruncated
func (d *MyDecimal) Round(to *MyDecimal, frac int, roundMode RoundMode) (err error) {
if frac > mysql.MaxDecimalScale {
frac = mysql.MaxDecimalScale
}
// wordsFracTo is the number of fraction words in buffer.
wordsFracTo := (frac + 1) / digitsPerWord
if frac > 0 {
Expand Down Expand Up @@ -1383,6 +1380,16 @@ func (d *MyDecimal) Compare(to *MyDecimal) int {
return 1
}

// DecimalNeg reverses decimal's sign.
func DecimalNeg(from *MyDecimal) *MyDecimal {
to := *from
if from.IsZero() {
return &to
}
to.negative = !from.negative
return &to
}

// DecimalAdd adds two decimals, sets the result to 'to'.
func DecimalAdd(from1, from2, to *MyDecimal) error {
to.resultFrac = myMaxInt8(from1.resultFrac, from2.resultFrac)
Expand Down Expand Up @@ -1753,7 +1760,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error {
to.digitsFrac = int8(wordsFracTo * digitsPerWord)
}
if to.digitsInt > int8(wordsIntTo*digitsPerWord) {
to.digitsInt = int8(wordsFracTo * digitsPerWord)
to.digitsInt = int8(wordsIntTo * digitsPerWord)
}
if tmp1 > wordsIntTo {
tmp1 -= wordsIntTo
Expand All @@ -1762,7 +1769,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error {
wordsFrac1 = 0
wordsFrac2 = 0
} else {
tmp2 -= wordsIntTo
tmp2 -= wordsFracTo
tmp1 = tmp2 >> 1
if wordsFrac1 <= wordsFrac2 {
wordsFrac1 -= tmp1
Expand All @@ -1774,9 +1781,9 @@ func DecimalMul(from1, from2, to *MyDecimal) error {
}
}
startTo := wordsIntTo + wordsFracTo - 1
start2 := wordsInt2 + wordsFrac2 - 1
stop1 := 0
stop2 := 0
start2 := idx2 + wordsFrac2 - 1
stop1 := idx1 - wordsInt1
stop2 := idx2 - wordsInt2
to.wordBuf = zeroMyDecimal.wordBuf

for idx1 += wordsFrac1 - 1; idx1 >= stop1; idx1-- {
Expand Down
24 changes: 22 additions & 2 deletions types/mydecimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,25 @@ func (s *testMyDecimalSuite) TestMaxDecimal(c *C) {
}
}

func (s *testMyDecimalSuite) TestNeg(c *C) {
type testCase struct {
a string
result string
err error
}
tests := []testCase{
{"-0.0000000000000000000000000000000000000000000000000017382578996420603", "0.0000000000000000000000000000000000000000000000000017382578996420603", nil},
{"-13890436710184412000000000000000000000000000000000000000000000000000000000000", "13890436710184412000000000000000000000000000000000000000000000000000000000000", nil},
{"0", "0", nil},
}
for _, tt := range tests {
a := NewDecFromStringForTest(tt.a)
negResult := DecimalNeg(a)
result := negResult.ToString()
c.Assert(string(result), Equals, tt.result)
}
}

func (s *testMyDecimalSuite) TestAdd(c *C) {
type testCase struct {
a string
Expand Down Expand Up @@ -627,6 +646,7 @@ func (s *testMyDecimalSuite) TestMul(c *C) {
{"123456", "9876543210", "1219318518533760", nil},
{"123", "0.01", "1.23", nil},
{"123", "0", "0", nil},
{"-0.0000000000000000000000000000000000000000000000000017382578996420603", "-13890436710184412000000000000000000000000000000000000000000000000000000000000", "0.000000000000000000000000000000", ErrTruncated},
{"1" + strings.Repeat("0", 60), "1" + strings.Repeat("0", 60), "0", ErrOverflow},
}
for _, tt := range tests {
Expand All @@ -635,8 +655,8 @@ func (s *testMyDecimalSuite) TestMul(c *C) {
b.FromString([]byte(tt.b))
err := DecimalMul(&a, &b, &product)
c.Check(err, Equals, tt.err)
result := product.ToString()
c.Assert(string(result), Equals, tt.result)
result := product.String()
c.Assert(result, Equals, tt.result)
}
}

Expand Down

0 comments on commit 5f7fc80

Please sign in to comment.