From 7bcb77c9fff12e1679693b330260b69c96b2b4d2 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Thu, 9 May 2019 11:11:02 +0800 Subject: [PATCH 1/3] expression: check if period is valid in `period_add` (#10380) --- expression/builtin_time.go | 14 ++++++++++---- expression/builtin_time_test.go | 4 ++-- expression/integration_test.go | 14 ++++++++++---- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/expression/builtin_time.go b/expression/builtin_time.go index ec055d6f4eb74..5585414531017 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -4887,6 +4887,11 @@ func (c *periodAddFunctionClass) getFunction(ctx sessionctx.Context, args []Expr return sig, nil } +// validPeriod checks if this period is valid, it comes from MySQL 8.0+. +func validPeriod(p int64) bool { + return !(p < 0 || p%100 == 0 || p%100 > 12) +} + // period2Month converts a period to months, in which period is represented in the format of YYMM or YYYYMM. // Note that the period argument is not a date value. func period2Month(period uint64) uint64 { @@ -4938,15 +4943,16 @@ func (b *builtinPeriodAddSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, true, errors.Trace(err) } - if p == 0 { - return 0, false, nil - } - n, isNull, err := b.args[1].EvalInt(b.ctx, row) if isNull || err != nil { return 0, true, errors.Trace(err) } + // in MySQL, if p is invalid but n is NULL, the result is NULL, so we have to check if n is NULL first. + if !validPeriod(p) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_add") + } + sumMonth := int64(period2Month(uint64(p))) + n return int64(month2Period(uint64(sumMonth))), false, nil } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index e52ddf09c44cb..3f59ef6e8acac 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -2146,8 +2146,8 @@ func (s *testEvaluatorSuite) TestPeriodAdd(c *C) { {201611, -13, true, 201510}, {1611, 3, true, 201702}, {7011, 3, true, 197102}, - {12323, 10, true, 12509}, - {0, 3, true, 0}, + {12323, 10, false, 0}, + {0, 3, false, 0}, } fc := funcs[ast.PeriodAdd] diff --git a/expression/integration_test.go b/expression/integration_test.go index ab96348ff247c..3feabea5e8391 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1438,10 +1438,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { result.Check(testkit.Rows("123456 10 ")) // for period_add - result = tk.MustQuery(`SELECT period_add(191, 2), period_add(191, -2), period_add(0, 20), period_add(0, 0);`) - result.Check(testkit.Rows("200809 200805 0 0")) - result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("21aa", "11aa"), period_add("", "");`) - result.Check(testkit.Rows(" 200010 200208 0")) + result = tk.MustQuery(`SELECT period_add(200807, 2), period_add(200807, -2);`) + result.Check(testkit.Rows("200809 200805")) + result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("200207aa", "1aa");`) + result.Check(testkit.Rows(" 200010 200208")) + for _, errPeriod := range []string{ + "period_add(0, 20)", "period_add(0, 0)", "period_add(-1, 1)", "period_add(200013, 1)", "period_add(-200012, 1)", "period_add('', '')", + } { + err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod)) + c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_add") + } // for period_diff result = tk.MustQuery(`SELECT period_diff(191, 2), period_diff(191, -2), period_diff(0, 0), period_diff(191, 191);`) From b88fd54d543b76ede175c84aa979c47fe4ba8f52 Mon Sep 17 00:00:00 2001 From: qw4990 Date: Mon, 13 May 2019 16:51:13 +0800 Subject: [PATCH 2/3] fix CI --- expression/integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expression/integration_test.go b/expression/integration_test.go index 3feabea5e8391..49967b2b8510e 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1445,7 +1445,7 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { for _, errPeriod := range []string{ "period_add(0, 20)", "period_add(0, 0)", "period_add(-1, 1)", "period_add(200013, 1)", "period_add(-200012, 1)", "period_add('', '')", } { - err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod)) + _, err := tk.Exec(fmt.Sprintf("SELECT %v;", errPeriod)) c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_add") } From cb8247e826b418c1cf3d9f100e312431c503e65e Mon Sep 17 00:00:00 2001 From: qw4990 Date: Wed, 15 May 2019 16:48:27 +0800 Subject: [PATCH 3/3] fix CI --- expression/integration_test.go | 2 +- util/testkit/testkit.go | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/expression/integration_test.go b/expression/integration_test.go index 8a22f8a8945cc..52330581451ca 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1457,7 +1457,7 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { for _, errPeriod := range []string{ "period_add(0, 20)", "period_add(0, 0)", "period_add(-1, 1)", "period_add(200013, 1)", "period_add(-200012, 1)", "period_add('', '')", } { - _, err := tk.Exec(fmt.Sprintf("SELECT %v;", errPeriod)) + err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod)) c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_add") } diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go index ca12e8a83cd3f..551a5da76d69d 100644 --- a/util/testkit/testkit.go +++ b/util/testkit/testkit.go @@ -180,6 +180,17 @@ func (tk *TestKit) MustQuery(sql string, args ...interface{}) *Result { return tk.ResultSetToResult(rs, comment) } +// QueryToErr executes a sql statement and discard results. +func (tk *TestKit) QueryToErr(sql string, args ...interface{}) error { + comment := check.Commentf("sql:%s, args:%v", sql, args) + res, err := tk.Exec(sql, args...) + tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment) + tk.c.Assert(res, check.NotNil, comment) + _, resErr := session.GetRows4Test(context.Background(), tk.Se, res) + tk.c.Assert(res.Close(), check.IsNil) + return resErr +} + // ResultSetToResult converts sqlexec.RecordSet to testkit.Result. // It is used to check results of execute statement in binary mode. func (tk *TestKit) ResultSetToResult(rs sqlexec.RecordSet, comment check.CommentInterface) *Result {