Skip to content

Commit

Permalink
types, *: move truncate flags to the types context (pingcap#47522)
Browse files Browse the repository at this point in the history
  • Loading branch information
YangKeao authored and wuhuizuo committed Apr 2, 2024
1 parent 443c38d commit 8b50a9d
Show file tree
Hide file tree
Showing 80 changed files with 593 additions and 495 deletions.
4 changes: 3 additions & 1 deletion br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,13 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
vars.StmtCtx.InInsertStmt = true
vars.StmtCtx.BatchCheck = true
vars.StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode()
vars.StmtCtx.TruncateAsWarning = !sqlMode.HasStrictMode()
vars.StmtCtx.OverflowAsWarning = !sqlMode.HasStrictMode()
vars.StmtCtx.AllowInvalidDate = sqlMode.HasAllowInvalidDatesMode()
vars.StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode()
vars.SQLMode = sqlMode

typeFlags := vars.StmtCtx.TypeFlags().WithTruncateAsWarning(!sqlMode.HasStrictMode())
vars.StmtCtx.SetTypeFlags(typeFlags)
if options.SysVars != nil {
for k, v := range options.SysVars {
// since 6.3(current master) tidb checks whether we can set a system variable
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/tidb/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func (enc *tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum, _ *ta

case types.KindMysqlBit:
var buffer [20]byte
intValue, err := datum.GetBinaryLiteral().ToInt(nil)
intValue, err := datum.GetBinaryLiteral().ToInt(types.DefaultNoWarningContext)
if err != nil {
return err
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,15 @@ func initSessCtx(
return errors.Trace(err)
}
sessCtx.GetSessionVars().StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.TruncateAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.OverflowAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.AllowInvalidDate = sqlMode.HasAllowInvalidDatesMode()
sessCtx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode()
sessCtx.GetSessionVars().StmtCtx.NoZeroDate = sqlMode.HasStrictMode()

typeFlags := sessCtx.GetSessionVars().StmtCtx.TypeFlags().WithTruncateAsWarning(!sqlMode.HasStrictMode())
sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags)

// Prevent initializing the mock context in the workers concurrently.
// For details, see https://github.com/pingcap/tidb/issues/40879.
_ = sessCtx.GetDomainInfoSchema()
Expand Down
11 changes: 5 additions & 6 deletions pkg/ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
return str, false, err
}
// For other kind of fields (e.g. INT), we supply its integer as string value.
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx)
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx)
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -5617,12 +5617,11 @@ func GetModifiableColumnJob(
}
pAst := at.Specs[0].Partition
sv := sctx.GetSessionVars().StmtCtx
oldTruncAsWarn, oldIgnoreTrunc := sv.TruncateAsWarning, sv.IgnoreTruncate.Load()
sv.TruncateAsWarning = false
sv.IgnoreTruncate.Store(false)
oldTypeFlags := sv.TypeFlags()
newTypeFlags := oldTypeFlags.WithTruncateAsWarning(false).WithIgnoreTruncateErr(false)
sv.SetTypeFlags(newTypeFlags)
_, err = buildPartitionDefinitionsInfo(sctx, pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)))
sv.TruncateAsWarning = oldTruncAsWarn
sv.IgnoreTruncate.Store(oldIgnoreTrunc)
sv.SetTypeFlags(oldTypeFlags)
if err != nil {
return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/aggfuncs/func_group_concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (e *baseGroupConcat4String) AppendFinalResult2Chunk(_ sessionctx.Context, p

func (e *baseGroupConcat4String) handleTruncateError(sctx sessionctx.Context) (err error) {
if atomic.CompareAndSwapInt32(e.truncated, 0, 1) {
if !sctx.GetSessionVars().StmtCtx.TruncateAsWarning {
if !sctx.GetSessionVars().StmtCtx.TypeFlags().TruncateAsWarning() {
return expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String())
}
sctx.GetSessionVars().StmtCtx.AppendWarning(expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String()))
Expand Down
6 changes: 3 additions & 3 deletions pkg/executor/coprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (exec
}

stmtCtx := h.sctx.GetSessionVars().StmtCtx
stmtCtx.SetFlagsFromPBFlag(dagReq.Flags)

tz, err := timeutil.ConstructTimeZone(dagReq.TimeZoneName, int(dagReq.TimeZoneOffset))
if err != nil {
return nil, errors.Trace(err)
}

stmtCtx.SetTimeZone(tz)
h.sctx.GetSessionVars().TimeZone = tz
stmtCtx.InitFromPBFlagAndTz(dagReq.Flags, tz)

h.dagReq = dagReq
is := h.sctx.GetInfoSchema().(infoschema.InfoSchema)
// Build physical plan.
Expand Down
38 changes: 22 additions & 16 deletions pkg/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,13 @@ func (e *CheckTableExec) Next(ctx context.Context, _ *chunk.Chunk) error {
}
defer func() { e.done = true }()

// See the comment of `ColumnInfos2ColumnsAndNames`. It's fixing #42341
originalTypeFlags := e.Ctx().GetSessionVars().StmtCtx.TypeFlags()
defer func() {
e.Ctx().GetSessionVars().StmtCtx.SetTypeFlags(originalTypeFlags)
}()
e.Ctx().GetSessionVars().StmtCtx.SetTypeFlags(originalTypeFlags.WithIgnoreTruncateErr(true))

idxNames := make([]string, 0, len(e.indexInfos))
for _, idx := range e.indexInfos {
if idx.MVIndex {
Expand Down Expand Up @@ -2062,6 +2069,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {

sc.InRestrictedSQL = vars.InRestrictedSQL
switch stmt := s.(type) {
// `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them.
case *ast.UpdateStmt:
ResetUpdateStmtCtx(sc, stmt, vars)
case *ast.DeleteStmt:
Expand All @@ -2075,17 +2083,17 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.IgnoreNoPartition = stmt.IgnoreErr
sc.ErrAutoincReadFailedAsWarning = stmt.IgnoreErr
sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate
sc.Priority = stmt.Priority
sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr))
case *ast.CreateTableStmt, *ast.AlterTableStmt:
sc.InCreateOrAlterStmt = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.StrictSQLMode || sc.AllowInvalidDate
sc.NoZeroDate = vars.SQLMode.HasNoZeroDateMode()
sc.TruncateAsWarning = !vars.StrictSQLMode
sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode))
case *ast.LoadDataStmt:
sc.InLoadDataStmt = true
// return warning instead of error when load data meet no partition for value
Expand All @@ -2100,7 +2108,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.OverflowAsWarning = true

// Return warning for truncate error in selection.
sc.TruncateAsWarning = true
sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true))
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
if opts := stmt.SelectStmtOpts; opts != nil {
Expand All @@ -2111,38 +2119,36 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
case *ast.SetOprStmt:
sc.InSelectStmt = true
sc.OverflowAsWarning = true
sc.TruncateAsWarning = true
sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true))
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
case *ast.ShowStmt:
sc.IgnoreTruncate.Store(true)
sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true))
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates {
sc.InShowWarning = true
sc.SetWarnings(vars.StmtCtx.GetWarnings())
}
case *ast.SplitRegionStmt:
sc.IgnoreTruncate.Store(false)
sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(false))
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
case *ast.SetSessionStatesStmt:
sc.InSetSessionStatesStmt = true
sc.IgnoreTruncate.Store(true)
sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true))
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
default:
sc.IgnoreTruncate.Store(true)
sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true))
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
}

sc.UpdateTypeFlags(func(flags types.Flags) types.Flags {
return flags.
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load())
})
sc.SetTypeFlags(sc.TypeFlags().
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()))

vars.PlanCacheParams.Reset()
if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority {
Expand Down Expand Up @@ -2192,24 +2198,24 @@ func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars
sc.InUpdateStmt = true
sc.DupKeyAsWarning = stmt.IgnoreErr
sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate
sc.Priority = stmt.Priority
sc.IgnoreNoPartition = stmt.IgnoreErr
sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr))
}

// ResetDeleteStmtCtx resets statement context for DeleteStmt.
func ResetDeleteStmtCtx(sc *stmtctx.StatementContext, stmt *ast.DeleteStmt, vars *variable.SessionVars) {
sc.InDeleteStmt = true
sc.DupKeyAsWarning = stmt.IgnoreErr
sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.DividedByZeroAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
sc.IgnoreZeroInDate = !vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || sc.AllowInvalidDate
sc.Priority = stmt.Priority
sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr))
}

func setOptionForTopSQL(sc *stmtctx.StatementContext, snapshot kv.Snapshot) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue
if err != nil && gCol.FieldType.IsArray() {
return nil, completeError(tbl, gCol.Offset, rowIdx, err)
}
if e.Ctx().GetSessionVars().StmtCtx.HandleTruncate(err) != nil {
if e.Ctx().GetSessionVars().StmtCtx.TypeCtx.HandleTruncate(err) != nil {
return nil, err
}
row[colIdx], err = table.CastValue(e.Ctx(), val, gCol.ToInfo(), false, false)
Expand Down Expand Up @@ -791,7 +791,7 @@ func setDatumAutoIDAndCast(ctx sessionctx.Context, d *types.Datum, id int64, col
// Auto ID is out of range.
sc := ctx.GetSessionVars().StmtCtx
insertPlan, ok := sc.GetPlan().(*core.Insert)
if ok && sc.TruncateAsWarning && len(insertPlan.OnDuplicate) > 0 {
if ok && sc.TypeFlags().TruncateAsWarning() && len(insertPlan.OnDuplicate) > 0 {
// Fix issue #38950: AUTO_INCREMENT is incompatible with mysql
// An auto id out of range error occurs in `insert ignore into ... on duplicate ...`.
// We should allow the SQL to be executed successfully.
Expand Down
3 changes: 2 additions & 1 deletion pkg/executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ func setNonRestrictiveFlags(stmtCtx *stmtctx.StatementContext) {
// TODO: DupKeyAsWarning represents too many "ignore error" paths, the
// meaning of this flag is not clear. I can only reuse it here.
stmtCtx.DupKeyAsWarning = true
stmtCtx.TruncateAsWarning = true
stmtCtx.BadNullAsWarning = true

stmtCtx.SetTypeFlags(stmtCtx.TypeFlags().WithTruncateAsWarning(true))
}

// NewLoadDataWorker creates a new LoadDataWorker that is ready to work.
Expand Down
6 changes: 3 additions & 3 deletions pkg/executor/test/loaddatatest/load_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ func TestLoadData(t *testing.T) {
selectSQL := "select * from load_data_test;"

sc := ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate.Load()
oldFlags := sc.TypeFlags()
defer func() {
sc.IgnoreTruncate.Store(originIgnoreTruncate)
sc.SetTypeFlags(oldFlags)
}()
sc.IgnoreTruncate.Store(false)
sc.SetTypeFlags(oldFlags.WithIgnoreTruncateErr(false))
// fields and lines are default, ReadOneBatchRows returns data is nil
tests := []testCase{
// In MySQL we have 4 warnings: 1*"Incorrect integer value: '' for column 'id' at row", 3*"Row 1 doesn't contain data for all columns"
Expand Down
6 changes: 3 additions & 3 deletions pkg/executor/test/writetest/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1317,11 +1317,11 @@ func TestIssue18681(t *testing.T) {
ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true

sc := ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate.Load()
oldTypeFlags := sc.TypeFlags()
defer func() {
sc.IgnoreTruncate.Store(originIgnoreTruncate)
sc.SetTypeFlags(oldTypeFlags)
}()
sc.IgnoreTruncate.Store(false)
sc.SetTypeFlags(oldTypeFlags.WithIgnoreTruncateErr(true))
tests := []testCase{
{[]byte("true\tfalse\t0\t1\n"), []string{"1|0|0|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"},
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (af *aggFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEv
evalCtx.Value.SetNull()
}

func (af *aggFunction) updateSum(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext, row chunk.Row) error {
func (af *aggFunction) updateSum(ctx types.Context, evalCtx *AggEvaluateContext, row chunk.Row) error {
a := af.Args[0]
value, err := a.Eval(row)
if err != nil {
Expand All @@ -158,7 +158,7 @@ func (af *aggFunction) updateSum(sc *stmtctx.StatementContext, evalCtx *AggEvalu
return nil
}
}
evalCtx.Value, err = calculateSum(sc, evalCtx.Value, value)
evalCtx.Value, err = calculateSum(ctx, evalCtx.Value, value)
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type avgFunction struct {
aggFunction
}

func (af *avgFunction) updateAvg(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext, row chunk.Row) error {
func (af *avgFunction) updateAvg(ctx types.Context, evalCtx *AggEvaluateContext, row chunk.Row) error {
a := af.Args[1]
value, err := a.Eval(row)
if err != nil {
Expand All @@ -36,7 +36,7 @@ func (af *avgFunction) updateAvg(sc *stmtctx.StatementContext, evalCtx *AggEvalu
if value.IsNull() {
return nil
}
evalCtx.Value, err = calculateSum(sc, evalCtx.Value, value)
evalCtx.Value, err = calculateSum(ctx, evalCtx.Value, value)
if err != nil {
return err
}
Expand All @@ -60,9 +60,9 @@ func (af *avgFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEv
func (af *avgFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) (err error) {
switch af.Mode {
case Partial1Mode, CompleteMode:
err = af.updateSum(sc, evalCtx, row)
err = af.updateSum(sc.TypeCtx, evalCtx, row)
case Partial2Mode, FinalMode:
err = af.updateAvg(sc, evalCtx, row)
err = af.updateAvg(sc.TypeCtx, evalCtx, row)
case DedupMode:
panic("DedupMode is not supported now.")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type sumFunction struct {

// Update implements Aggregation interface.
func (sf *sumFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, row chunk.Row) error {
return sf.updateSum(sc, evalCtx, row)
return sf.updateSum(sc.TypeCtx, evalCtx, row)
}

// GetResult implements Aggregation interface.
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/aggregation/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (d *distinctChecker) Check(values []types.Datum) (bool, error) {
}

// calculateSum adds v to sum.
func calculateSum(sc *stmtctx.StatementContext, sum, v types.Datum) (data types.Datum, err error) {
func calculateSum(ctx types.Context, sum, v types.Datum) (data types.Datum, err error) {
// for avg and sum calculation
// avg and sum use decimal for integer and decimal type, use float for others
// see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
Expand All @@ -64,15 +64,15 @@ func calculateSum(sc *stmtctx.StatementContext, sum, v types.Datum) (data types.
case types.KindNull:
case types.KindInt64, types.KindUint64:
var d *types.MyDecimal
d, err = v.ToDecimal(sc)
d, err = v.ToDecimal(ctx)
if err == nil {
data = types.NewDecimalDatum(d)
}
case types.KindMysqlDecimal:
v.Copy(&data)
default:
var f float64
f, err = v.ToFloat64(sc)
f, err = v.ToFloat64(ctx)
if err == nil {
data = types.NewFloat64Datum(f)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row chunk.Row) (*types.M
return c, true, handleDivisionByZeroError(s.ctx)
} else if err == types.ErrTruncated {
sc := s.ctx.GetSessionVars().StmtCtx
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
} else if err == nil {
_, frac := c.PrecisionAndFrac()
if frac < s.baseBuiltinFunc.tp.GetDecimal() {
Expand Down Expand Up @@ -846,7 +846,7 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64
return 0, true, handleDivisionByZeroError(s.ctx)
}
if err == types.ErrTruncated {
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
}
if err == types.ErrOverflow {
newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_arithmetic_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(input *chunk.Chunk, r
result.SetNull(i, true)
continue
} else if err == types.ErrTruncated {
if err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil {
if err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", to)); err != nil {
return err
}
} else if err == nil {
Expand Down Expand Up @@ -617,7 +617,7 @@ func (b *builtinArithmeticIntDivideDecimalSig) vecEvalInt(input *chunk.Chunk, re
continue
}
if err == types.ErrTruncated {
err = sc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
err = sc.TypeCtx.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c))
} else if err == types.ErrOverflow {
newErr := errTruncatedWrongValue.GenWithStackByArgs("DECIMAL", c)
err = sc.HandleOverflow(newErr, newErr)
Expand Down
Loading

0 comments on commit 8b50a9d

Please sign in to comment.