Skip to content

Commit

Permalink
expression: fix the behavior when adding date with big interval
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Dec 7, 2023
1 parent 373608f commit 8a6bb2c
Show file tree
Hide file tree
Showing 9 changed files with 687 additions and 78 deletions.
1 change: 1 addition & 0 deletions pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/config",
"//pkg/errctx",
"//pkg/errno",
"//pkg/extension",
"//pkg/kv",
Expand Down
111 changes: 76 additions & 35 deletions pkg/expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
Expand Down Expand Up @@ -2754,7 +2755,7 @@ type baseDateArithmetical struct {

func newDateArithmeticalUtil() baseDateArithmetical {
return baseDateArithmetical{
intervalRegexp: regexp.MustCompile(`-?[\d]+`),
intervalRegexp: regexp.MustCompile(`^[+-]?[\d]+`),
}
}

Expand Down Expand Up @@ -2864,17 +2865,55 @@ func (du *baseDateArithmetical) getIntervalFromString(ctx sessionctx.Context, ar
if isNull || err != nil {
return "", true, err
}
// unit "DAY" and "HOUR" has to be specially handled.
if toLower := strings.ToLower(unit); toLower == "day" || toLower == "hour" {
if strings.ToLower(interval) == "true" {
interval = "1"
} else if strings.ToLower(interval) == "false" {

interval, err = du.intervalReformatString(ctx.GetSessionVars().StmtCtx.ErrCtx(), interval, unit)
return interval, false, err
}

func (du *baseDateArithmetical) intervalReformatString(ec errctx.Context, str string, unit string) (interval string, err error) {
switch strings.ToUpper(unit) {
case "MICROSECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR":
str = strings.TrimSpace(str)
// a single unit value has to be specially handled.
interval = du.intervalRegexp.FindString(str)
if interval == "" {
interval = "0"
} else {
interval = du.intervalRegexp.FindString(interval)
}

if interval != str {
err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", str))
}
case "SECOND":
// The unit SECOND is specially handled, for example:
// date + INTERVAL "1e2" SECOND = date + INTERVAL 100 second
// date + INTERVAL "1.6" SECOND = date + INTERVAL 1.6 second
// But:
// date + INTERVAL "1e2" MINUTE = date + INTERVAL 1 MINUTE
// date + INTERVAL "1.6" MINUTE = date + INTERVAL 1 MINUTE
var dec types.MyDecimal
err = ec.HandleError(dec.FromString([]byte(str)))
interval = string(dec.ToString())
default:
interval = str
}
return interval, false, nil
return interval, err
}

func (du *baseDateArithmetical) intervalDecimalToString(tc types.Context, dec *types.MyDecimal) (string, error) {
var rounded types.MyDecimal
err := dec.Round(&rounded, 0, types.ModeHalfUp)
if err != nil {
return "", err
}

intVal, err := rounded.ToInt()
if err != nil {
if err = tc.HandleTruncate(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", dec.String())); err != nil {
return "", err
}
}

return strconv.FormatInt(intVal, 10), nil
}

func (du *baseDateArithmetical) getIntervalFromDecimal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) {
Expand Down Expand Up @@ -2921,9 +2960,8 @@ func (du *baseDateArithmetical) getIntervalFromDecimal(ctx sessionctx.Context, a
// interval is already like the %f format.
default:
// YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND
castExpr := WrapWithCastAsString(ctx, WrapWithCastAsInt(ctx, args[1]))
interval, isNull, err = castExpr.EvalString(ctx, row)
if isNull || err != nil {
interval, err = du.intervalDecimalToString(ctx.GetSessionVars().StmtCtx.TypeCtx(), decimal)
if err != nil {
return "", true, err
}
}
Expand All @@ -2936,6 +2974,11 @@ func (du *baseDateArithmetical) getIntervalFromInt(ctx sessionctx.Context, args
if isNull || err != nil {
return "", true, err
}

if mysql.HasUnsignedFlag(args[1].GetType().GetFlag()) {
return strconv.FormatUint(uint64(interval), 10), false, nil
}

return strconv.FormatInt(interval, 10), false, nil
}

Expand All @@ -2962,7 +3005,10 @@ func (du *baseDateArithmetical) addDate(ctx sessionctx.Context, date types.Time,
}

goTime = goTime.Add(time.Duration(nano))
goTime = types.AddDate(year, month, day, goTime)
goTime, err = types.AddDate(year, month, day, goTime)
if err != nil {
return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}

// Adjust fsp as required by outer - always respect type inference.
date.SetFsp(resultFsp)
Expand All @@ -2974,10 +3020,6 @@ func (du *baseDateArithmetical) addDate(ctx sessionctx.Context, date types.Time,
return date, false, nil
}

if goTime.Year() < 0 || goTime.Year() > 9999 {
return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime"))
}

date.SetCoreTime(types.FromGoTime(goTime))
overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx.TypeCtx(), date)
if err := handleInvalidTimeError(ctx, err); err != nil {
Expand Down Expand Up @@ -3236,28 +3278,19 @@ func (du *baseDateArithmetical) vecGetIntervalFromString(b *baseBuiltinFunc, ctx
return err
}

amendInterval := func(val string) string {
return val
}
if unitLower := strings.ToLower(unit); unitLower == "day" || unitLower == "hour" {
amendInterval = func(val string) string {
if intervalLower := strings.ToLower(val); intervalLower == "true" {
return "1"
} else if intervalLower == "false" {
return "0"
}
return du.intervalRegexp.FindString(val)
}
}

ec := ctx.GetSessionVars().StmtCtx.ErrCtx()
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}

result.AppendString(amendInterval(buf.GetString(i)))
interval, err := du.intervalReformatString(ec, buf.GetString(i), unit)
if err != nil {
return err
}
result.AppendString(interval)
}
return nil
}
Expand Down Expand Up @@ -3325,10 +3358,18 @@ func (du *baseDateArithmetical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, ct
/* keep interval as original decimal */
default:
// YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND
castExpr := WrapWithCastAsString(ctx, WrapWithCastAsInt(ctx, b.args[1]))
amendInterval = func(_ string, row *chunk.Row) (string, bool, error) {
interval, isNull, err := castExpr.EvalString(ctx, *row)
return interval, isNull || err != nil, err
dec, isNull, err := b.args[1].EvalDecimal(ctx, *row)
if isNull || err != nil {
return "", true, err
}

str, err := du.intervalDecimalToString(ctx.GetSessionVars().StmtCtx.TypeCtx(), dec)
if err != nil {
return "", true, err
}

return str, false, nil
}
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2247,11 +2247,11 @@ func TestTimeBuiltin(t *testing.T) {
{"\"2011-11-11 10:10:10\"", "\"20\"", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "19.88", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"19.88\"", "DAY", "2011-11-30 10:10:10", "2011-10-23 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"prefix19suffix\"", "DAY", "2011-11-30 10:10:10", "2011-10-23 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"prefix19suffix\"", "DAY", "2011-11-11 10:10:10", "2011-11-11 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"20-11\"", "DAY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"20,11\"", "daY", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"1000\"", "dAy", "2014-08-07 10:10:10", "2009-02-14 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"true\"", "Day", "2011-11-12 10:10:10", "2011-11-10 10:10:10"},
{"\"2011-11-11 10:10:10\"", "\"true\"", "Day", "2011-11-11 10:10:10", "2011-11-11 10:10:10"},
{"\"2011-11-11 10:10:10\"", "true", "Day", "2011-11-12 10:10:10", "2011-11-10 10:10:10"},
{"\"2011-11-11\"", "1", "DAY", "2011-11-12", "2011-11-10"},
{"\"2011-11-11\"", "10", "HOUR", "2011-11-11 10:00:00", "2011-11-10 14:00:00"},
Expand Down Expand Up @@ -2329,8 +2329,8 @@ func TestTimeBuiltin(t *testing.T) {
{"\"2009-01-01\"", "6/0", "HOUR_MINUTE", "<nil>", "<nil>"},
{"\"1970-01-01 12:00:00\"", "CAST(6/4 AS DECIMAL(3,1))", "HOUR_MINUTE", "1970-01-01 13:05:00", "1970-01-01 10:55:00"},
// for issue #8077
{"\"2012-01-02\"", "\"prefix8\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"prefix8prefix\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"prefix8\"", "HOUR", "2012-01-02 00:00:00", "2012-01-02 00:00:00"},
{"\"2012-01-02\"", "\"prefix8prefix\"", "HOUR", "2012-01-02 00:00:00", "2012-01-02 00:00:00"},
{"\"2012-01-02\"", "\"8:00\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
{"\"2012-01-02\"", "\"8:00:00\"", "HOUR", "2012-01-02 08:00:00", "2012-01-01 16:00:00"},
}
Expand Down
20 changes: 18 additions & 2 deletions pkg/types/core_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,30 @@ func compareTime(a, b CoreTime) int {
// Dig it and we found it's caused by golang api time.Date(year int, month Month, day, hour, min, sec, nsec int, loc *Location) Time ,
// it says October 32 converts to November 1 ,it conflicts with mysql.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
func AddDate(year, month, day int64, ot gotime.Time) (nt gotime.Time) {
func AddDate(year, month, day int64, ot gotime.Time) (nt gotime.Time, _ error) {
// We must limit the range of year, month and day to avoid overflow.
// The datetime range is from '1000-01-01 00:00:00.000000' to '9999-12-31 23:59:59.499999',
// so it is safe to limit the added value from -10000*365 to 10000*365.
const maxAdd = 10000 * 365
const minAdd = -maxAdd
if year > maxAdd || year < minAdd ||
month > maxAdd || month < minAdd ||
day > maxAdd || day < minAdd {
return nt, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")
}

df := getFixDays(int(year), int(month), int(day), ot)
if df != 0 {
nt = ot.AddDate(int(year), int(month), df)
} else {
nt = ot.AddDate(int(year), int(month), int(day))
}
return nt

if nt.Year() < 0 || nt.Year() > 9999 {
return nt, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")
}

return nt, nil
}

func calcTimeFromSec(to *CoreTime, seconds, microseconds int) {
Expand Down
29 changes: 23 additions & 6 deletions pkg/types/core_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,33 @@ func TestAddDate(t *testing.T) {
month int
day int
ot time.Time
err bool
}{
{01, 1, 0, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{02, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{03, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC)},
{04, 2, 24, time.Date(2000, 2, 10, 0, 0, 0, 0, time.UTC)},
{01, 04, 05, time.Date(2019, 04, 01, 1, 2, 3, 4, time.UTC)},
{01, 1, 0, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{02, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{03, 1, 12, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{04, 2, 24, time.Date(2000, 2, 10, 0, 0, 0, 0, time.UTC), false},
{01, 04, 05, time.Date(2019, 04, 01, 1, 2, 3, 4, time.UTC), false},
{7999, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{-2000, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), false},
{8000, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{10001 * 365, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 10001 * 36, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 1, 10001 * 365, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{-2001, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{-10001 * 365, 1, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, -10001 * 36, 1, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
{01, 1, -10001 * 365, time.Date(2000, 1, 01, 0, 0, 0, 0, time.UTC), true},
}

for _, tt := range tests {
res := AddDate(int64(tt.year), int64(tt.month), int64(tt.day), tt.ot)
res, err := AddDate(int64(tt.year), int64(tt.month), int64(tt.day), tt.ot)
if tt.err {
require.EqualError(t, err, ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime").Error())
require.True(t, ErrDatetimeFunctionOverflow.Equal(err))
continue
}
require.NoError(t, err)
require.Equal(t, tt.year+tt.ot.Year(), res.Year())
}
}
Expand Down
61 changes: 35 additions & 26 deletions pkg/types/time.go
Original file line number Diff line number Diff line change
Expand Up @@ -2301,9 +2301,12 @@ func parseSingleTimeValue(unit string, format string, strictCheck bool) (year in
if len(format) > 0 && format[0] == '-' {
sign = int64(-1)
}

// We should also continue even if an error occurs here
// because the called may ignore the error and use the return value.
iv, err := strconv.ParseInt(format[0:decimalPointPos], 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, format)
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, format)
}
riv := iv // Rounded integer value

Expand All @@ -2312,22 +2315,23 @@ func parseSingleTimeValue(unit string, format string, strictCheck bool) (year in
lf := len(format) - 1
// Has fraction part
if decimalPointPos < lf {
var tmpErr error
dvPre := oneToSixDigitRegex.FindString(format[decimalPointPos+1:]) // the numberical prefix of the fraction part
decimalLen = len(dvPre)
if decimalLen >= 6 {
// MySQL rounds down to 1e-6.
if dv, err = strconv.ParseInt(dvPre[0:6], 10, 64); err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, format)
if dv, tmpErr = strconv.ParseInt(dvPre[0:6], 10, 64); tmpErr != nil && err == nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, format)
}
} else {
if dv, err = strconv.ParseInt(dvPre+"000000"[:6-decimalLen], 10, 64); err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, format)
if dv, tmpErr = strconv.ParseInt(dvPre+"000000"[:6-decimalLen], 10, 64); tmpErr != nil && err == nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, format)
}
}
if dv >= 500000 { // Round up, and we should keep 6 digits for microsecond, so dv should in [000000, 999999].
riv += sign
}
if unit != "SECOND" {
if unit != "SECOND" && err == nil {
err = ErrTruncatedWrongVal.GenWithStackByArgs(format)
}
dv *= sign
Expand Down Expand Up @@ -2421,39 +2425,44 @@ func parseTimeValue(format string, index, cnt int) (years int64, months int64, d
index--
}

// ParseInt may return an error when overflowed, but we should continue to parse the rest of the string because
// the caller may ignore the error and use the return value.
// In this case, we should return a big value to make sure the result date after adding this interval
// is also overflowed and NULL is returned to the user.
years, err = strconv.ParseInt(fields[YearIndex], 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
}
months, err = strconv.ParseInt(fields[MonthIndex], 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
var tmpErr error
months, tmpErr = strconv.ParseInt(fields[MonthIndex], 10, 64)
if err == nil && tmpErr != nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
}
days, err = strconv.ParseInt(fields[DayIndex], 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
days, tmpErr = strconv.ParseInt(fields[DayIndex], 10, 64)
if err == nil && tmpErr != nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
}

hours, err := strconv.ParseInt(fields[HourIndex], 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
hours, tmpErr := strconv.ParseInt(fields[HourIndex], 10, 64)
if tmpErr != nil && err == nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
}
minutes, err := strconv.ParseInt(fields[MinuteIndex], 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
minutes, tmpErr := strconv.ParseInt(fields[MinuteIndex], 10, 64)
if tmpErr != nil && err == nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
}
seconds, err := strconv.ParseInt(fields[SecondIndex], 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
seconds, tmpErr := strconv.ParseInt(fields[SecondIndex], 10, 64)
if tmpErr != nil && err == nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
}
microseconds, err := strconv.ParseInt(alignFrac(fields[MicrosecondIndex], MaxFsp), 10, 64)
if err != nil {
return 0, 0, 0, 0, 0, ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
microseconds, tmpErr := strconv.ParseInt(alignFrac(fields[MicrosecondIndex], MaxFsp), 10, 64)
if tmpErr != nil && err == nil {
err = ErrWrongValue.GenWithStackByArgs(DateTimeStr, originalFmt)
}
seconds = hours*3600 + minutes*60 + seconds
days += seconds / (3600 * 24)
seconds %= 3600 * 24
return years, months, days, seconds*int64(gotime.Second) + microseconds*int64(gotime.Microsecond), fsp, nil
return years, months, days, seconds*int64(gotime.Second) + microseconds*int64(gotime.Microsecond), fsp, err
}

func parseAndValidateDurationValue(format string, index, cnt int) (int64, int, error) {
Expand Down
Loading

0 comments on commit 8a6bb2c

Please sign in to comment.