diff --git a/br/pkg/lightning/backend/kv/session.go b/br/pkg/lightning/backend/kv/session.go index e1aca81bd581e..727e4de60ef0a 100644 --- a/br/pkg/lightning/backend/kv/session.go +++ b/br/pkg/lightning/backend/kv/session.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" "go.uber.org/zap" ) @@ -313,6 +314,9 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { } } vars.StmtCtx.SetTimeZone(vars.Location()) + vars.StmtCtx.SetTypeFlags(types.StrictFlags. + WithClipNegativeToZero(true), + ) if err := vars.SetSystemVar("timestamp", strconv.FormatInt(options.Timestamp, 10)); err != nil { logger.Warn("new session: failed to set timestamp", log.ShortError(err)) diff --git a/pkg/ddl/backfilling_scheduler.go b/pkg/ddl/backfilling_scheduler.go index 5a6a5d9008b78..d70e0432879bd 100644 --- a/pkg/ddl/backfilling_scheduler.go +++ b/pkg/ddl/backfilling_scheduler.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "sync" - "time" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/ddl/copr" @@ -33,6 +32,7 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/dbterror" "github.com/pingcap/tidb/pkg/util/intest" @@ -148,12 +148,6 @@ func initSessCtx( sqlMode mysql.SQLMode, tzLocation *model.TimeZoneLocation, ) error { - // Unify the TimeZone settings in newContext. - if sessCtx.GetSessionVars().StmtCtx.TimeZone() == nil { - tz := *time.UTC - sessCtx.GetSessionVars().StmtCtx.SetTimeZone(&tz) - } - sessCtx.GetSessionVars().StmtCtx.IsDDLJobInQueue = true // Set the row encode format version. rowFormat := variable.GetDDLReorgRowFormat() sessCtx.GetSessionVars().RowEncoder.Enable = rowFormat != variable.DefTiDBRowFormatV1 @@ -162,15 +156,17 @@ func initSessCtx( if err := setSessCtxLocation(sessCtx, tzLocation); err != nil { return errors.Trace(err) } + sessCtx.GetSessionVars().StmtCtx.SetTimeZone(sessCtx.GetSessionVars().Location()) sessCtx.GetSessionVars().StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode() sessCtx.GetSessionVars().StmtCtx.OverflowAsWarning = !sqlMode.HasStrictMode() sessCtx.GetSessionVars().StmtCtx.AllowInvalidDate = sqlMode.HasAllowInvalidDatesMode() sessCtx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = !sqlMode.HasStrictMode() sessCtx.GetSessionVars().StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode() sessCtx.GetSessionVars().StmtCtx.NoZeroDate = sqlMode.HasStrictMode() - - typeFlags := sessCtx.GetSessionVars().StmtCtx.TypeFlags().WithTruncateAsWarning(!sqlMode.HasStrictMode()) - sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags) + sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(types.StrictFlags. + WithTruncateAsWarning(!sqlMode.HasStrictMode()). + WithClipNegativeToZero(true), + ) // Prevent initializing the mock context in the workers concurrently. // For details, see https://github.com/pingcap/tidb/issues/40879. diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index 34c10220cb66b..54406e2b1495f 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -2162,7 +2162,12 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.SetTypeFlags(sc.TypeFlags(). WithSkipUTF8Check(vars.SkipUTF8Check). WithSkipSACIICheck(vars.SkipASCIICheck). - WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load())) + WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()). + // WithClipNegativeToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types. + // This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode. + // see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html + WithClipNegativeToZero(sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt), + ) vars.PlanCacheParams.Reset() if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority { diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index b9f54d441a58d..d01710fa044e1 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -480,7 +480,9 @@ var fakeSctx = newFakeSctx() func newFakeSctx() *stmtctx.StatementContext { sc := stmtctx.NewStmtCtx() - sc.InInsertStmt = true + sc.SetTypeFlags(types.StrictFlags. + WithClipNegativeToZero(true), + ) return sc } @@ -980,7 +982,7 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool } else { var uintVal uint64 sc := b.ctx.GetSessionVars().StmtCtx - uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) + uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) res = int64(uintVal) } if types.ErrOverflow.Equal(err) { diff --git a/pkg/expression/builtin_cast_vec.go b/pkg/expression/builtin_cast_vec.go index 5b8f0f682bf41..d06e9d09985cd 100644 --- a/pkg/expression/builtin_cast_vec.go +++ b/pkg/expression/builtin_cast_vec.go @@ -766,7 +766,7 @@ func (b *builtinCastRealAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.C } else { var uintVal uint64 sc := b.ctx.GetSessionVars().StmtCtx - uintVal, err = types.ConvertFloatToUint(sc, f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) + uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) i64s[i] = int64(uintVal) } if types.ErrOverflow.Equal(err) { diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 8aa3a0e933996..3c97bbdc5baad 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -473,12 +473,6 @@ func (sc *StatementContext) SetTypeFlags(flags typectx.Flags) { sc.typeCtx = sc.typeCtx.WithFlags(flags) } -// UpdateTypeFlags updates the flags of the type context -func (sc *StatementContext) UpdateTypeFlags(fn func(typectx.Flags) typectx.Flags) { - flags := fn(sc.typeCtx.Flags()) - sc.typeCtx = sc.typeCtx.WithFlags(flags) -} - // HandleTruncate ignores or returns the error based on the TypeContext inside. func (sc *StatementContext) HandleTruncate(err error) error { return sc.typeCtx.HandleTruncate(err) @@ -1133,13 +1127,6 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails { return details } -// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types. -// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode. -// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html -func (sc *StatementContext) ShouldClipToZero() bool { - return sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt || sc.IsDDLJobInQueue -} - // ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows, // so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types. func (sc *StatementContext) ShouldIgnoreOverflowError() bool { @@ -1236,12 +1223,11 @@ func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location) sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0 sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0 sc.SetTimeZone(tz) - - typeFlags := sc.TypeFlags() - typeFlags = typeFlags. + sc.SetTypeFlags(typectx.StrictFlags. WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0). - WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0) - sc.typeCtx = typectx.NewContext(typeFlags, tz, sc.AppendWarning) + WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0). + WithClipNegativeToZero(sc.InInsertStmt), + ) } // GetLockWaitStartTime returns the statement pessimistic lock wait start time diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index 09b4ed41c2263..c1e03af0166d6 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -356,12 +356,6 @@ func TestSetStmtCtxTypeFlags(t *testing.T) { sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagInvalidDateAsWarning) require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagInvalidDateAsWarning, sc.TypeFlags()) require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) - - sc.UpdateTypeFlags(func(flags typectx.Flags) typectx.Flags { - return (flags | typectx.FlagSkipUTF8Check | typectx.FlagClipNegativeToZero) &^ typectx.FlagSkipASCIICheck - }) - require.Equal(t, typectx.FlagSkipUTF8Check|typectx.FlagClipNegativeToZero|typectx.FlagInvalidDateAsWarning, sc.TypeFlags()) - require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) } func TestResetStmtCtx(t *testing.T) { diff --git a/pkg/store/mockstore/mockcopr/cop_handler_dag.go b/pkg/store/mockstore/mockcopr/cop_handler_dag.go index 1b176602ce3fd..c1fac6fa803d3 100644 --- a/pkg/store/mockstore/mockcopr/cop_handler_dag.go +++ b/pkg/store/mockstore/mockcopr/cop_handler_dag.go @@ -460,7 +460,7 @@ func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][ // flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.InitFromPBFlagAndTz(flags, tz) return sc } diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go b/pkg/store/mockstore/unistore/cophandler/cop_handler.go index 43daa0716978d..c4e4c3bb3c766 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go @@ -423,7 +423,7 @@ func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType, // flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`. func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext { - sc := new(stmtctx.StatementContext) + sc := stmtctx.NewStmtCtx() sc.InitFromPBFlagAndTz(flags, tz) return sc } diff --git a/pkg/table/column.go b/pkg/table/column.go index 15be521014ca8..c7e7a8469c055 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -713,7 +713,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd } // Clip to zero if get negative value after cast to unsigned. - if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.ShouldClipToZero() { + if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.TypeFlags().ClipNegativeToZero() { switch datum.Kind() { case types.KindInt64: if datum.GetInt64() < 0 { diff --git a/pkg/types/context/context.go b/pkg/types/context/context.go index 41e62ad1482c3..59497437e52d8 100644 --- a/pkg/types/context/context.go +++ b/pkg/types/context/context.go @@ -65,6 +65,19 @@ const ( FlagSkipUTF8MB4Check ) +// ClipNegativeToZero indicates whether the flag `FlagClipNegativeToZero` is set +func (f Flags) ClipNegativeToZero() bool { + return f&FlagClipNegativeToZero != 0 +} + +// WithClipNegativeToZero returns a new flags with `FlagClipNegativeToZero` set/unset according to the clip parameter +func (f Flags) WithClipNegativeToZero(clip bool) Flags { + if clip { + return f | FlagClipNegativeToZero + } + return f &^ FlagClipNegativeToZero +} + // SkipASCIICheck indicates whether the flag `FlagSkipASCIICheck` is set func (f Flags) SkipASCIICheck() bool { return f&FlagSkipASCIICheck != 0 diff --git a/pkg/types/context/context_test.go b/pkg/types/context/context_test.go index 9a140a2ceb476..a905c728d94af 100644 --- a/pkg/types/context/context_test.go +++ b/pkg/types/context/context_test.go @@ -31,13 +31,23 @@ func TestWithNewFlags(t *testing.T) { require.Equal(t, time.UTC, ctx2.Location()) } -func TestStringFlags(t *testing.T) { +func TestSimpleOnOffFlags(t *testing.T) { cases := []struct { name string flag Flags - readFn func(f Flags) bool - writeFn func(f Flags, skip bool) Flags + readFn func(Flags) bool + writeFn func(Flags, bool) Flags }{ + { + name: "FlagClipNegativeToZero", + flag: FlagClipNegativeToZero, + readFn: func(f Flags) bool { + return f.ClipNegativeToZero() + }, + writeFn: func(f Flags, clip bool) Flags { + return f.WithClipNegativeToZero(clip) + }, + }, { name: "FlagSkipASCIICheck", flag: FlagSkipASCIICheck, diff --git a/pkg/types/convert.go b/pkg/types/convert.go index 4798103cf3654..97a859b487fe9 100644 --- a/pkg/types/convert.go +++ b/pkg/types/convert.go @@ -145,8 +145,8 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) { } // ConvertIntToUint converts an int value to an uint value. -func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) { - if sc.ShouldClipToZero() && val < 0 { +func ConvertIntToUint(flags Flags, val int64, upperBound uint64, tp byte) (uint64, error) { + if val < 0 && flags.ClipNegativeToZero() { return 0, overflow(val, tp) } @@ -167,10 +167,10 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) { } // ConvertFloatToUint converts a float value to an uint value. -func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) { +func ConvertFloatToUint(flags Flags, fval float64, upperBound uint64, tp byte) (uint64, error) { val := RoundFloat(fval) if val < 0 { - if sc.ShouldClipToZero() { + if flags.ClipNegativeToZero() { return 0, overflow(val, tp) } return uint64(int64(val)), overflow(val, tp) @@ -586,7 +586,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool, i := j.GetInt64() if unsigned { uBound := IntergerUnsignedUpperBound(tp) - u, err := ConvertIntToUint(sc, i, uBound, tp) + u, err := ConvertIntToUint(sc.TypeFlags(), i, uBound, tp) return int64(u), sc.HandleOverflow(err, err) } @@ -614,7 +614,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool, return u, sc.HandleOverflow(e, e) } bound := IntergerUnsignedUpperBound(tp) - u, err := ConvertFloatToUint(sc, f, bound, tp) + u, err := ConvertFloatToUint(sc.TypeFlags(), f, bound, tp) return int64(u), sc.HandleOverflow(err, err) case JSONTypeCodeString: str := string(hack.String(j.GetString())) diff --git a/pkg/types/datum.go b/pkg/types/datum.go index 3a2524a9b19ab..868c2f6fd39f3 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -1194,11 +1194,11 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( ) switch d.k { case KindInt64: - val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp) + val, err = ConvertIntToUint(sc.TypeFlags(), d.GetInt64(), upperBound, tp) case KindUint64: val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp) case KindFloat32, KindFloat64: - val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp) + val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetFloat64(), upperBound, tp) case KindString, KindBytes: uval, err1 := StrToUint(sc.TypeCtxOrDefault(), d.GetString(), false) if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() { @@ -1215,7 +1215,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( if err == nil { err = err1 } - val, err1 = ConvertIntToUint(sc, ival, upperBound, tp) + val, err1 = ConvertIntToUint(sc.TypeFlags(), ival, upperBound, tp) if err == nil { err = err1 } @@ -1230,9 +1230,9 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( case KindMysqlDecimal: val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp) case KindMysqlEnum: - val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlEnum().ToNumber(), upperBound, tp) case KindMysqlSet: - val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: val, err = d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) if err == nil {