From 5f7fc8038941ad12a10a700420c7ad3f02cfed0e Mon Sep 17 00:00:00 2001 From: lysu Date: Thu, 12 Jul 2018 00:02:19 +0800 Subject: [PATCH] expression, types: fix decimal minus/round/multiple result (#7001) --- expression/aggregation/avg.go | 3 +- expression/builtin_arithmetic.go | 3 +- expression/builtin_math.go | 50 +++++++++++++++++--------------- expression/builtin_op.go | 6 +--- expression/integration_test.go | 4 ++- expression/typeinfer_test.go | 6 ++-- types/mydecimal.go | 23 ++++++++++----- types/mydecimal_test.go | 24 +++++++++++++-- 8 files changed, 74 insertions(+), 45 deletions(-) diff --git a/expression/aggregation/avg.go b/expression/aggregation/avg.go index 83481d060f634..7f60b54c94026 100644 --- a/expression/aggregation/avg.go +++ b/expression/aggregation/avg.go @@ -14,6 +14,7 @@ package aggregation import ( + "github.com/cznic/mathutil" "github.com/juju/errors" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -84,7 +85,7 @@ func (af *avgFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum) { if frac == -1 { frac = mysql.MaxDecimalScale } - err = to.Round(to, frac, types.ModeHalfEven) + err = to.Round(to, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven) terror.Log(errors.Trace(err)) d.SetMysqlDecimal(to) } diff --git a/expression/builtin_arithmetic.go b/expression/builtin_arithmetic.go index 75495371cd15c..4b75e5c6a3ad7 100644 --- a/expression/builtin_arithmetic.go +++ b/expression/builtin_arithmetic.go @@ -21,6 +21,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tipb/go-tipb" ) @@ -523,7 +524,7 @@ func (s *builtinArithmeticMultiplyDecimalSig) evalDecimal(row types.Row) (*types } c := &types.MyDecimal{} err = types.DecimalMul(a, b, c) - if err != nil { + if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) { return nil, true, errors.Trace(err) } return c, false, nil diff --git a/expression/builtin_math.go b/expression/builtin_math.go index e5282aca05e93..d01320d2abf8d 100644 --- a/expression/builtin_math.go +++ b/expression/builtin_math.go @@ -26,6 +26,7 @@ import ( "strings" "time" + "github.com/cznic/mathutil" "github.com/juju/errors" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" @@ -263,8 +264,10 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi if mysql.HasUnsignedFlag(argFieldTp.Flag) { bf.tp.Flag |= mysql.UnsignedFlag } + bf.tp.Flen = argFieldTp.Flen - bf.tp.Decimal = 0 + bf.tp.Decimal = calculateDecimal4RoundAndTruncate(ctx, args, argTp) + var sig builtinFunc if len(args) > 1 { switch argTp { @@ -292,6 +295,25 @@ func (c *roundFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi return sig, nil } +// calculateDecimal4RoundAndTruncate calculates tp.decimals of round/truncate func. +func calculateDecimal4RoundAndTruncate(ctx sessionctx.Context, args []Expression, retType types.EvalType) int { + if retType == types.ETInt || len(args) <= 1 { + return 0 + } + secondConst, secondIsConst := args[1].(*Constant) + if !secondIsConst { + return args[0].GetType().Decimal + } + argDec, isNull, err := secondConst.EvalInt(ctx, nil) + if err != nil || isNull || argDec < 0 { + return 0 + } + if argDec > mysql.MaxDecimalScale { + return mysql.MaxDecimalScale + } + return int(argDec) +} + type builtinRoundRealSig struct { baseBuiltinFunc } @@ -422,7 +444,7 @@ func (b *builtinRoundWithFracDecSig) evalDecimal(row types.Row) (*types.MyDecima return nil, isNull, errors.Trace(err) } to := new(types.MyDecimal) - if err = val.Round(to, int(frac), types.ModeHalfEven); err != nil { + if err = val.Round(to, mathutil.Min(int(frac), b.tp.Decimal), types.ModeHalfEven); err != nil { return nil, true, errors.Trace(err) } return to, false, nil @@ -1695,22 +1717,6 @@ type truncateFunctionClass struct { baseFunctionClass } -// getDecimal returns the `Decimal` value of return type for function `TRUNCATE`. -func (c *truncateFunctionClass) getDecimal(ctx sessionctx.Context, arg Expression) int { - if constant, ok := arg.(*Constant); ok { - decimal, isNull, err := constant.EvalInt(ctx, nil) - if isNull || err != nil { - return 0 - } else if decimal > 30 { - return 30 - } else if decimal < 0 { - return 0 - } - return int(decimal) - } - return 3 -} - func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) @@ -1723,11 +1729,7 @@ func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expre bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, argTp, types.ETInt) - if argTp == types.ETInt { - bf.tp.Decimal = 0 - } else { - bf.tp.Decimal = c.getDecimal(bf.ctx, args[1]) - } + bf.tp.Decimal = calculateDecimal4RoundAndTruncate(ctx, args, argTp) bf.tp.Flen = args[0].GetType().Flen - args[0].GetType().Decimal + bf.tp.Decimal bf.tp.Flag |= args[0].GetType().Flag @@ -1768,7 +1770,7 @@ func (b *builtinTruncateDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal } result := new(types.MyDecimal) - if err := x.Round(result, int(d), types.ModeTruncate); err != nil { + if err := x.Round(result, mathutil.Min(int(d), b.getRetTp().Decimal), types.ModeTruncate); err != nil { return nil, true, errors.Trace(err) } return result, false, nil diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 3cdb698addd9b..875fb5cc1f692 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -749,15 +749,11 @@ func (b *builtinUnaryMinusDecimalSig) Clone() builtinFunc { } func (b *builtinUnaryMinusDecimalSig) evalDecimal(row types.Row) (*types.MyDecimal, bool, error) { - var dec *types.MyDecimal dec, isNull, err := b.args[0].EvalDecimal(b.ctx, row) if err != nil || isNull { return dec, isNull, errors.Trace(err) } - - to := new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), dec, to) - return to, err != nil, errors.Trace(err) + return types.DecimalNeg(dec), false, nil } type builtinUnaryMinusRealSig struct { diff --git a/expression/integration_test.go b/expression/integration_test.go index 4f80d9933cf48..6d86f04e9f8d5 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3372,7 +3372,7 @@ func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) { return store, dom, errors.Trace(err) } -func (s *testIntegrationSuite) TestTwoDecimalAssignTruncate(c *C) { +func (s *testIntegrationSuite) TestTwoDecimalTruncate(c *C) { tk := testkit.NewTestKit(c, s.store) defer s.cleanEnv(c) tk.MustExec("use test") @@ -3383,4 +3383,6 @@ func (s *testIntegrationSuite) TestTwoDecimalAssignTruncate(c *C) { tk.MustExec("update t1 set b = a") res := tk.MustQuery("select a, b from t1") res.Check(testkit.Rows("123.12345 123.1")) + res = tk.MustQuery("select 2.00000000000000000000000000000001 * 1.000000000000000000000000000000000000000000002") + res.Check(testkit.Rows("2.000000000000000000000000000000")) } diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 10d8819eabba5..5d18e45dc829a 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -591,9 +591,9 @@ func (s *testInferTypeSuite) createTestCase4MathFuncs() []typeInferTestCase { {"round(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, {"round(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0}, - {"round(c_float_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 12, 0}, // Should be 17. - {"round(c_double_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, // Should be 17. - {"round(c_decimal )", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, // Should be 5. + {"round(c_float_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 12, 0}, // flen Should be 17. + {"round(c_double_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 0}, // flen Should be 17. + {"round(c_decimal )", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, // flen Should be 5. {"round(c_datetime )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"round(c_time_d )", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, {"round(c_timestamp_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, diff --git a/types/mydecimal.go b/types/mydecimal.go index 5185f503c3b71..71391b415b2dd 100644 --- a/types/mydecimal.go +++ b/types/mydecimal.go @@ -713,9 +713,6 @@ func (d *MyDecimal) doMiniRightShift(shift, beg, end int) { // RETURN VALUE // eDecOK/eDecTruncated func (d *MyDecimal) Round(to *MyDecimal, frac int, roundMode RoundMode) (err error) { - if frac > mysql.MaxDecimalScale { - frac = mysql.MaxDecimalScale - } // wordsFracTo is the number of fraction words in buffer. wordsFracTo := (frac + 1) / digitsPerWord if frac > 0 { @@ -1383,6 +1380,16 @@ func (d *MyDecimal) Compare(to *MyDecimal) int { return 1 } +// DecimalNeg reverses decimal's sign. +func DecimalNeg(from *MyDecimal) *MyDecimal { + to := *from + if from.IsZero() { + return &to + } + to.negative = !from.negative + return &to +} + // DecimalAdd adds two decimals, sets the result to 'to'. func DecimalAdd(from1, from2, to *MyDecimal) error { to.resultFrac = myMaxInt8(from1.resultFrac, from2.resultFrac) @@ -1753,7 +1760,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error { to.digitsFrac = int8(wordsFracTo * digitsPerWord) } if to.digitsInt > int8(wordsIntTo*digitsPerWord) { - to.digitsInt = int8(wordsFracTo * digitsPerWord) + to.digitsInt = int8(wordsIntTo * digitsPerWord) } if tmp1 > wordsIntTo { tmp1 -= wordsIntTo @@ -1762,7 +1769,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error { wordsFrac1 = 0 wordsFrac2 = 0 } else { - tmp2 -= wordsIntTo + tmp2 -= wordsFracTo tmp1 = tmp2 >> 1 if wordsFrac1 <= wordsFrac2 { wordsFrac1 -= tmp1 @@ -1774,9 +1781,9 @@ func DecimalMul(from1, from2, to *MyDecimal) error { } } startTo := wordsIntTo + wordsFracTo - 1 - start2 := wordsInt2 + wordsFrac2 - 1 - stop1 := 0 - stop2 := 0 + start2 := idx2 + wordsFrac2 - 1 + stop1 := idx1 - wordsInt1 + stop2 := idx2 - wordsInt2 to.wordBuf = zeroMyDecimal.wordBuf for idx1 += wordsFrac1 - 1; idx1 >= stop1; idx1-- { diff --git a/types/mydecimal_test.go b/types/mydecimal_test.go index 0b2ffec4c7cb7..e72b9658327da 100644 --- a/types/mydecimal_test.go +++ b/types/mydecimal_test.go @@ -543,6 +543,25 @@ func (s *testMyDecimalSuite) TestMaxDecimal(c *C) { } } +func (s *testMyDecimalSuite) TestNeg(c *C) { + type testCase struct { + a string + result string + err error + } + tests := []testCase{ + {"-0.0000000000000000000000000000000000000000000000000017382578996420603", "0.0000000000000000000000000000000000000000000000000017382578996420603", nil}, + {"-13890436710184412000000000000000000000000000000000000000000000000000000000000", "13890436710184412000000000000000000000000000000000000000000000000000000000000", nil}, + {"0", "0", nil}, + } + for _, tt := range tests { + a := NewDecFromStringForTest(tt.a) + negResult := DecimalNeg(a) + result := negResult.ToString() + c.Assert(string(result), Equals, tt.result) + } +} + func (s *testMyDecimalSuite) TestAdd(c *C) { type testCase struct { a string @@ -627,6 +646,7 @@ func (s *testMyDecimalSuite) TestMul(c *C) { {"123456", "9876543210", "1219318518533760", nil}, {"123", "0.01", "1.23", nil}, {"123", "0", "0", nil}, + {"-0.0000000000000000000000000000000000000000000000000017382578996420603", "-13890436710184412000000000000000000000000000000000000000000000000000000000000", "0.000000000000000000000000000000", ErrTruncated}, {"1" + strings.Repeat("0", 60), "1" + strings.Repeat("0", 60), "0", ErrOverflow}, } for _, tt := range tests { @@ -635,8 +655,8 @@ func (s *testMyDecimalSuite) TestMul(c *C) { b.FromString([]byte(tt.b)) err := DecimalMul(&a, &b, &product) c.Check(err, Equals, tt.err) - result := product.ToString() - c.Assert(string(result), Equals, tt.result) + result := product.String() + c.Assert(result, Equals, tt.result) } }