From f5296960f2b52d146aae135b43569f51a0124893 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Fri, 10 Sep 2021 17:48:39 +0800 Subject: [PATCH] expression: fix extract bug when argument is a negative duration (#27318) (#27369) --- expression/integration_test.go | 13 ++++++ types/time.go | 34 ++++++++------- types/time_test.go | 80 ++++++++++++++++++++++++---------- 3 files changed, 88 insertions(+), 39 deletions(-) diff --git a/expression/integration_test.go b/expression/integration_test.go index cf5a712a33b8d..5905fc8356f3e 100755 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -8284,6 +8284,19 @@ func (s *testIntegrationSuite) TestJiraSetInnoDBDefaultRowFormat(c *C) { tk.MustQuery("SHOW VARIABLES LIKE 'innodb_large_prefix'").Check(testkit.Rows("innodb_large_prefix ON")) } +func (s *testIntegrationSuite) TestIssue27236(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + row := tk.MustQuery(`select extract(hour_second from "-838:59:59.00");`) + row.Check(testkit.Rows("-8385959")) + + tk.MustExec(`drop table if exists t`) + tk.MustExec(`create table t(c1 varchar(100));`) + tk.MustExec(`insert into t values('-838:59:59.00'), ('700:59:59.00');`) + row = tk.MustQuery(`select extract(hour_second from c1) from t order by c1;`) + row.Check(testkit.Rows("-8385959", "7005959")) +} + func (s *testIntegrationSuite) TestConstPropNullFunctions(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/types/time.go b/types/time.go index b740a06c02569..3fab85307e25e 100644 --- a/types/time.go +++ b/types/time.go @@ -2182,39 +2182,43 @@ func ExtractDatetimeNum(t *Time, unit string) (int64, error) { } // ExtractDurationNum extracts duration value number from duration unit and format. -func ExtractDurationNum(d *Duration, unit string) (int64, error) { +func ExtractDurationNum(d *Duration, unit string) (res int64, err error) { switch strings.ToUpper(unit) { case "MICROSECOND": - return int64(d.MicroSecond()), nil + res = int64(d.MicroSecond()) case "SECOND": - return int64(d.Second()), nil + res = int64(d.Second()) case "MINUTE": - return int64(d.Minute()), nil + res = int64(d.Minute()) case "HOUR": - return int64(d.Hour()), nil + res = int64(d.Hour()) case "SECOND_MICROSECOND": - return int64(d.Second())*1000000 + int64(d.MicroSecond()), nil + res = int64(d.Second())*1000000 + int64(d.MicroSecond()) case "MINUTE_MICROSECOND": - return int64(d.Minute())*100000000 + int64(d.Second())*1000000 + int64(d.MicroSecond()), nil + res = int64(d.Minute())*100000000 + int64(d.Second())*1000000 + int64(d.MicroSecond()) case "MINUTE_SECOND": - return int64(d.Minute()*100 + d.Second()), nil + res = int64(d.Minute()*100 + d.Second()) case "HOUR_MICROSECOND": - return int64(d.Hour())*10000000000 + int64(d.Minute())*100000000 + int64(d.Second())*1000000 + int64(d.MicroSecond()), nil + res = int64(d.Hour())*10000000000 + int64(d.Minute())*100000000 + int64(d.Second())*1000000 + int64(d.MicroSecond()) case "HOUR_SECOND": - return int64(d.Hour())*10000 + int64(d.Minute())*100 + int64(d.Second()), nil + res = int64(d.Hour())*10000 + int64(d.Minute())*100 + int64(d.Second()) case "HOUR_MINUTE": - return int64(d.Hour())*100 + int64(d.Minute()), nil + res = int64(d.Hour())*100 + int64(d.Minute()) case "DAY_MICROSECOND": - return int64(d.Hour()*10000+d.Minute()*100+d.Second())*1000000 + int64(d.MicroSecond()), nil + res = int64(d.Hour()*10000+d.Minute()*100+d.Second())*1000000 + int64(d.MicroSecond()) case "DAY_SECOND": - return int64(d.Hour())*10000 + int64(d.Minute())*100 + int64(d.Second()), nil + res = int64(d.Hour())*10000 + int64(d.Minute())*100 + int64(d.Second()) case "DAY_MINUTE": - return int64(d.Hour())*100 + int64(d.Minute()), nil + res = int64(d.Hour())*100 + int64(d.Minute()) case "DAY_HOUR": - return int64(d.Hour()), nil + res = int64(d.Hour()) default: return 0, errors.Errorf("invalid unit %s", unit) } + if d.Duration < 0 { + res = -res + } + return res, nil } // parseSingleTimeValue parse the format according the given unit. If we set strictCheck true, we'll check whether diff --git a/types/time_test.go b/types/time_test.go index 9dc3c4851e486..3ff65fb96b08d 100644 --- a/types/time_test.go +++ b/types/time_test.go @@ -1556,35 +1556,67 @@ func (s *testTimeSuite) TestExtractDatetimeNum(c *C) { } func (s *testTimeSuite) TestExtractDurationNum(c *C) { - in := types.Duration{Duration: time.Duration(3600 * 24 * 365), Fsp: types.DefaultFsp} - tbl := []struct { + type resultTbl struct { unit string expect int64 - }{ - {"MICROSECOND", 31536}, - {"SECOND", 0}, - {"MINUTE", 0}, - {"HOUR", 0}, - {"SECOND_MICROSECOND", 31536}, - {"MINUTE_MICROSECOND", 31536}, - {"MINUTE_SECOND", 0}, - {"HOUR_MICROSECOND", 31536}, - {"HOUR_SECOND", 0}, - {"HOUR_MINUTE", 0}, - {"DAY_MICROSECOND", 31536}, - {"DAY_SECOND", 0}, - {"DAY_MINUTE", 0}, - {"DAY_HOUR", 0}, + } + type testCase struct { + in types.Duration + resTbls []resultTbl + } + cases := []testCase{ + { + in: types.Duration{Duration: time.Duration(3600 * 24 * 365), Fsp: types.DefaultFsp}, + resTbls: []resultTbl{ + {"MICROSECOND", 31536}, + {"SECOND", 0}, + {"MINUTE", 0}, + {"HOUR", 0}, + {"SECOND_MICROSECOND", 31536}, + {"MINUTE_MICROSECOND", 31536}, + {"MINUTE_SECOND", 0}, + {"HOUR_MICROSECOND", 31536}, + {"HOUR_SECOND", 0}, + {"HOUR_MINUTE", 0}, + {"DAY_MICROSECOND", 31536}, + {"DAY_SECOND", 0}, + {"DAY_MINUTE", 0}, + {"DAY_HOUR", 0}, + }, + }, + { + // "-10:59:1" = -10^9 * (10 * 3600 + 59 * 60 + 1) + in: types.Duration{Duration: time.Duration(-39541000000000), Fsp: types.DefaultFsp}, + resTbls: []resultTbl{ + {"MICROSECOND", 0}, + {"SECOND", -1}, + {"MINUTE", -59}, + {"HOUR", -10}, + {"SECOND_MICROSECOND", -1000000}, + {"MINUTE_MICROSECOND", -5901000000}, + {"MINUTE_SECOND", -5901}, + {"HOUR_MICROSECOND", -105901000000}, + {"HOUR_SECOND", -105901}, + {"HOUR_MINUTE", -1059}, + {"DAY_MICROSECOND", -105901000000}, + {"DAY_SECOND", -105901}, + {"DAY_MINUTE", -1059}, + {"DAY_HOUR", -10}, + }, + }, } - for _, col := range tbl { - res, err := types.ExtractDurationNum(&in, col.unit) - c.Assert(err, IsNil) - c.Assert(res, Equals, col.expect) + for _, testcase := range cases { + in := testcase.in + for _, col := range testcase.resTbls { + res, err := types.ExtractDurationNum(&in, col.unit) + c.Assert(err, IsNil) + c.Assert(res, Equals, col.expect) + } + res, err := types.ExtractDurationNum(&in, "TEST_ERROR") + c.Assert(res, Equals, int64(0)) + c.Assert(err, ErrorMatches, "invalid unit.*") } - res, err := types.ExtractDurationNum(&in, "TEST_ERROR") - c.Assert(res, Equals, int64(0)) - c.Assert(err, ErrorMatches, "invalid unit.*") } func (s *testTimeSuite) TestParseDurationValue(c *C) {