diff --git a/pkg/ddl/column_modify_test.go b/pkg/ddl/column_modify_test.go index d09b7f9715203..d733698203625 100644 --- a/pkg/ddl/column_modify_test.go +++ b/pkg/ddl/column_modify_test.go @@ -133,7 +133,7 @@ AddLoop: func(_ kv.Handle, data []types.Datum, cols []*table.Column) (bool, error) { i++ // c4 must be -1 or > 0 - v, err := data[3].ToInt64(tk.Session().GetSessionVars().StmtCtx) + v, err := data[3].ToInt64(tk.Session().GetSessionVars().StmtCtx.TypeCtx()) require.NoError(t, err) if v == -1 { j++ diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index d80aec23dce38..745437f81a980 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -1368,7 +1368,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu val, err := getEnumDefaultValue(v, col) return val, false, err case mysql.TypeDuration, mysql.TypeDate: - if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil { + if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err != nil { return "", false, errors.Trace(err) } case mysql.TypeBit: @@ -1380,7 +1380,7 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu // For these types, convert it to standard format firstly. // like integer fields, convert it into integer string literals. like convert "1.25" into "1" and "2.8" into "3". // if raise a error, we will use original expression. We will handle it in check phase - if temp, err := v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err == nil { + if temp, err := v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType); err == nil { v = temp } } @@ -7838,7 +7838,7 @@ func checkAndGetColumnsTypeAndValuesMatch(ctx sessionctx.Context, colTypes []typ return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() } } - newVal, err := val.ConvertTo(ctx.GetSessionVars().StmtCtx, &colType) + newVal, err := val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), &colType) if err != nil { return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() } diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index 721d257c03798..23be2ff7b71af 100644 --- a/pkg/ddl/partition.go +++ b/pkg/ddl/partition.go @@ -1097,7 +1097,7 @@ func GeneratePartDefsFromInterval(ctx sessionctx.Context, tp ast.AlterTableType, if err != nil { return err } - cmp, err := currVal.Compare(ctx.GetSessionVars().StmtCtx, &lastVal, collate.GetBinaryCollator()) + cmp, err := currVal.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &lastVal, collate.GetBinaryCollator()) if err != nil { return err } @@ -1427,7 +1427,7 @@ func checkPartitionValuesIsInt(ctx sessionctx.Context, defName interface{}, expr return dbterror.ErrValuesIsNotIntType.GenWithStackByArgs(defName) } - _, err = val.ConvertTo(ctx.GetSessionVars().StmtCtx, tp) + _, err = val.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), tp) if err != nil && !types.ErrOverflow.Equal(err) { return dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() } diff --git a/pkg/executor/aggfuncs/aggfunc_test.go b/pkg/executor/aggfuncs/aggfunc_test.go index 9b6c1ad5acb74..7b24739078281 100644 --- a/pkg/executor/aggfuncs/aggfunc_test.go +++ b/pkg/executor/aggfuncs/aggfunc_test.go @@ -301,7 +301,7 @@ func testMergePartialResult(t *testing.T, p aggTest) { if p.funcName == ast.AggFuncJsonArrayagg { dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) } - result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[0], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) @@ -328,7 +328,7 @@ func testMergePartialResult(t *testing.T, p aggTest) { if p.funcName == ast.AggFuncJsonArrayagg { dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) } - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[1], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) _, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr) @@ -351,7 +351,7 @@ func testMergePartialResult(t *testing.T, p aggTest) { if p.funcName == ast.AggFuncJsonArrayagg { dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON)) } - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[2], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[2]) } @@ -410,7 +410,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, p.retType) - result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[0], ctor) require.NoError(t, err) require.Zero(t, result) @@ -431,7 +431,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, p.retType) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[1], ctor) require.NoError(t, err) require.Zero(t, result) _, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr) @@ -442,7 +442,7 @@ func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p mul require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, p.retType) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[2], ctor) require.NoError(t, err) require.Zero(t, result) } @@ -540,7 +540,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[1], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) @@ -550,7 +550,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[0], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) @@ -583,7 +583,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[1], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1]) @@ -593,7 +593,7 @@ func testAggFunc(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[0], ctor) require.NoError(t, err) require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0]) } @@ -630,7 +630,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[1], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[1]) @@ -640,7 +640,7 @@ func testAggFuncWithoutDistinct(t *testing.T, p aggTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[0], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[0]) } @@ -708,7 +708,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[1], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[1]) @@ -718,7 +718,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[0], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[0]) @@ -751,7 +751,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[1], ctor) require.NoError(t, err) require.Zerof(t, result, "%v != %v", dt.String(), p.results[1]) @@ -761,7 +761,7 @@ func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTe err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor) + result, err = dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[0], ctor) require.NoError(t, err) require.Zero(t, result) } diff --git a/pkg/executor/aggfuncs/builder.go b/pkg/executor/aggfuncs/builder.go index 487f5a56626aa..ab759e9cf6764 100644 --- a/pkg/executor/aggfuncs/builder.go +++ b/pkg/executor/aggfuncs/builder.go @@ -704,7 +704,7 @@ func buildLeadLag(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, if len(aggFuncDesc.Args) == 3 { defaultExpr = aggFuncDesc.Args[2] if et, ok := defaultExpr.(*expression.Constant); ok { - res, err1 := et.Value.ConvertTo(ctx.GetSessionVars().StmtCtx, aggFuncDesc.RetTp) + res, err1 := et.Value.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), aggFuncDesc.RetTp) if err1 == nil { defaultExpr = &expression.Constant{Value: res, RetType: aggFuncDesc.RetTp} } diff --git a/pkg/executor/aggfuncs/func_group_concat.go b/pkg/executor/aggfuncs/func_group_concat.go index afe1d204f7fb9..c67a6d7805972 100644 --- a/pkg/executor/aggfuncs/func_group_concat.go +++ b/pkg/executor/aggfuncs/func_group_concat.go @@ -304,7 +304,7 @@ func (h topNRows) Len() int { func (h topNRows) Less(i, j int) bool { n := len(h.rows[i].byItems) for k := 0; k < n; k++ { - ret, err := h.rows[i].byItems[k].Compare(h.sctx.GetSessionVars().StmtCtx, h.rows[j].byItems[k], h.collators[k]) + ret, err := h.rows[i].byItems[k].Compare(h.sctx.GetSessionVars().StmtCtx.TypeCtx(), h.rows[j].byItems[k], h.collators[k]) if err != nil { h.err = err return false diff --git a/pkg/executor/aggfuncs/window_func_test.go b/pkg/executor/aggfuncs/window_func_test.go index 3c109aae4a95a..f657d01a89b49 100644 --- a/pkg/executor/aggfuncs/window_func_test.go +++ b/pkg/executor/aggfuncs/window_func_test.go @@ -76,7 +76,7 @@ func testWindowFunc(t *testing.T, p windowTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[i], collate.GetCollator(desc.RetTp.GetCollate())) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &p.results[i], collate.GetCollator(desc.RetTp.GetCollate())) require.NoError(t, err) require.Equal(t, 0, result) resultChk.Reset() diff --git a/pkg/executor/analyze_test.go b/pkg/executor/analyze_test.go index 9cbd34c415772..d97003a305308 100644 --- a/pkg/executor/analyze_test.go +++ b/pkg/executor/analyze_test.go @@ -37,7 +37,7 @@ import ( func checkHistogram(sc *stmtctx.StatementContext, hg *statistics.Histogram) (bool, error) { for i := 0; i < len(hg.Buckets); i++ { lower, upper := hg.GetLower(i), hg.GetUpper(i) - cmp, err := upper.Compare(sc, lower, collate.GetBinaryCollator()) + cmp, err := upper.Compare(sc.TypeCtx(), lower, collate.GetBinaryCollator()) if cmp < 0 || err != nil { return false, err } @@ -45,7 +45,7 @@ func checkHistogram(sc *stmtctx.StatementContext, hg *statistics.Histogram) (boo continue } previousUpper := hg.GetUpper(i - 1) - cmp, err = lower.Compare(sc, previousUpper, collate.GetBinaryCollator()) + cmp, err = lower.Compare(sc.TypeCtx(), previousUpper, collate.GetBinaryCollator()) if cmp <= 0 || err != nil { return false, err } diff --git a/pkg/executor/foreign_key.go b/pkg/executor/foreign_key.go index 512399c9931a6..55f34f8b02b64 100644 --- a/pkg/executor/foreign_key.go +++ b/pkg/executor/foreign_key.go @@ -174,7 +174,7 @@ func (fkc *FKCheckExec) updateRowNeedToCheck(sc *stmtctx.StatementContext, oldRo if len(oldVals) == len(newVals) { isSameValue := true for i := range oldVals { - cmp, err := oldVals[i].Compare(sc, &newVals[i], collate.GetCollator(oldVals[i].Collation())) + cmp, err := oldVals[i].Compare(sc.TypeCtx(), &newVals[i], collate.GetCollator(oldVals[i].Collation())) if err != nil || cmp != 0 { isSameValue = false break diff --git a/pkg/executor/index_lookup_join.go b/pkg/executor/index_lookup_join.go index 750321b8c0d3e..bb5aa2bd5c603 100644 --- a/pkg/executor/index_lookup_join.go +++ b/pkg/executor/index_lookup_join.go @@ -627,7 +627,7 @@ func (iw *innerWorker) constructDatumLookupKey(task *lookUpJoinTask, chkIdx, row return nil, nil, nil } innerColType := iw.rowTypes[iw.hashCols[i]] - innerValue, err := outerValue.ConvertTo(sc, innerColType) + innerValue, err := outerValue.ConvertTo(sc.TypeCtx(), innerColType) if err != nil && !(terror.ErrorEqual(err, types.ErrTruncated) && (innerColType.GetType() == mysql.TypeSet || innerColType.GetType() == mysql.TypeEnum)) { // If the converted outerValue overflows or invalid to innerValue, we don't need to lookup it. if terror.ErrorEqual(err, types.ErrOverflow) || terror.ErrorEqual(err, types.ErrWarnDataOutOfRange) { @@ -635,7 +635,7 @@ func (iw *innerWorker) constructDatumLookupKey(task *lookUpJoinTask, chkIdx, row } return nil, nil, err } - cmp, err := outerValue.Compare(sc, &innerValue, iw.hashCollators[i]) + cmp, err := outerValue.Compare(sc.TypeCtx(), &innerValue, iw.hashCollators[i]) if err != nil { return nil, nil, err } @@ -675,7 +675,7 @@ func (iw *innerWorker) sortAndDedupLookUpContents(lookUpContents []*indexJoinLoo func compareRow(sc *stmtctx.StatementContext, left, right []types.Datum, ctors []collate.Collator) int { for idx := 0; idx < len(left); idx++ { - cmp, err := left[idx].Compare(sc, &right[idx], ctors[idx]) + cmp, err := left[idx].Compare(sc.TypeCtx(), &right[idx], ctors[idx]) // We only compare rows with the same type, no error to return. terror.Log(err) if cmp > 0 { diff --git a/pkg/executor/index_lookup_merge_join.go b/pkg/executor/index_lookup_merge_join.go index c190a3d0eb0fb..8a43375a4bbdb 100644 --- a/pkg/executor/index_lookup_merge_join.go +++ b/pkg/executor/index_lookup_merge_join.go @@ -671,7 +671,7 @@ func (imw *innerMergeWorker) constructDatumLookupKey(task *lookUpMergeJoinTask, return nil, nil } innerColType := imw.rowTypes[imw.keyCols[i]] - innerValue, err := outerValue.ConvertTo(sc, innerColType) + innerValue, err := outerValue.ConvertTo(sc.TypeCtx(), innerColType) if err != nil { // If the converted outerValue overflows, we don't need to lookup it. if terror.ErrorEqual(err, types.ErrOverflow) || terror.ErrorEqual(err, types.ErrWarnDataOutOfRange) { @@ -682,7 +682,7 @@ func (imw *innerMergeWorker) constructDatumLookupKey(task *lookUpMergeJoinTask, } return nil, err } - cmp, err := outerValue.Compare(sc, &innerValue, imw.keyCollators[i]) + cmp, err := outerValue.Compare(sc.TypeCtx(), &innerValue, imw.keyCollators[i]) if err != nil { return nil, err } diff --git a/pkg/executor/insert_common.go b/pkg/executor/insert_common.go index f7f1c48d525ae..9a2d4d25ceef3 100644 --- a/pkg/executor/insert_common.go +++ b/pkg/executor/insert_common.go @@ -1368,7 +1368,7 @@ func (e *InsertValues) equalDatumsAsBinary(a []types.Datum, b []types.Datum) (bo return false, nil } for i, ai := range a { - v, err := ai.Compare(e.Ctx().GetSessionVars().StmtCtx, &b[i], collate.GetBinaryCollator()) + v, err := ai.Compare(e.Ctx().GetSessionVars().StmtCtx.TypeCtx(), &b[i], collate.GetBinaryCollator()) if err != nil { return false, errors.Trace(err) } diff --git a/pkg/executor/union_scan.go b/pkg/executor/union_scan.go index 83d3cb06ddb98..947e7c9589068 100644 --- a/pkg/executor/union_scan.go +++ b/pkg/executor/union_scan.go @@ -298,7 +298,7 @@ func (ce compareExec) compare(sctx *stmtctx.StatementContext, a, b []types.Datum for _, colOff := range ce.usedIndex { aColumn := a[colOff] bColumn := b[colOff] - cmp, err = aColumn.Compare(sctx, &bColumn, ce.collators[colOff]) + cmp, err = aColumn.Compare(sctx.TypeCtx(), &bColumn, ce.collators[colOff]) if err != nil { return 0, err } diff --git a/pkg/executor/write.go b/pkg/executor/write.go index 39187586979d5..28733152208ea 100644 --- a/pkg/executor/write.go +++ b/pkg/executor/write.go @@ -90,7 +90,7 @@ func updateRecord( // Compare datum, then handle some flags. for i, col := range t.Cols() { // We should use binary collation to compare datum, otherwise the result will be incorrect. - cmp, err := newData[i].Compare(sc, &oldData[i], collate.GetBinaryCollator()) + cmp, err := newData[i].Compare(sc.TypeCtx(), &oldData[i], collate.GetBinaryCollator()) if err != nil { return false, err } diff --git a/pkg/expression/aggregation/bit_and.go b/pkg/expression/aggregation/bit_and.go index 2d975f6563fcc..8d7adb6593850 100644 --- a/pkg/expression/aggregation/bit_and.go +++ b/pkg/expression/aggregation/bit_and.go @@ -47,7 +47,7 @@ func (bf *bitAndFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.Statem if value.Kind() == types.KindUint64 { evalCtx.Value.SetUint64(evalCtx.Value.GetUint64() & value.GetUint64()) } else { - int64Value, err := value.ToInt64(sc) + int64Value, err := value.ToInt64(sc.TypeCtx()) if err != nil { return err } diff --git a/pkg/expression/aggregation/bit_or.go b/pkg/expression/aggregation/bit_or.go index 409085855919c..35c72a0e70a1d 100644 --- a/pkg/expression/aggregation/bit_or.go +++ b/pkg/expression/aggregation/bit_or.go @@ -45,7 +45,7 @@ func (bf *bitOrFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.Stateme if value.Kind() == types.KindUint64 { evalCtx.Value.SetUint64(evalCtx.Value.GetUint64() | value.GetUint64()) } else { - int64Value, err := value.ToInt64(sc) + int64Value, err := value.ToInt64(sc.TypeCtx()) if err != nil { return err } diff --git a/pkg/expression/aggregation/bit_xor.go b/pkg/expression/aggregation/bit_xor.go index c3c97d5bd1712..47582cf844f15 100644 --- a/pkg/expression/aggregation/bit_xor.go +++ b/pkg/expression/aggregation/bit_xor.go @@ -45,7 +45,7 @@ func (bf *bitXorFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.Statem if value.Kind() == types.KindUint64 { evalCtx.Value.SetUint64(evalCtx.Value.GetUint64() ^ value.GetUint64()) } else { - int64Value, err := value.ToInt64(sc) + int64Value, err := value.ToInt64(sc.TypeCtx()) if err != nil { return err } diff --git a/pkg/expression/aggregation/max_min.go b/pkg/expression/aggregation/max_min.go index 2dd7696cdf9e7..cffdae9a2e22e 100644 --- a/pkg/expression/aggregation/max_min.go +++ b/pkg/expression/aggregation/max_min.go @@ -51,7 +51,7 @@ func (mmf *maxMinFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.State return nil } var c int - c, err = evalCtx.Value.Compare(sc, &value, mmf.ctor) + c, err = evalCtx.Value.Compare(sc.TypeCtx(), &value, mmf.ctor) if err != nil { return err } diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index c4d0d39e0b864..49230bd12b5d5 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -757,7 +757,7 @@ func (b *builtinCastIntAsTimeSig) evalTime(row chunk.Row) (res types.Time, isNul } if b.args[0].GetType().GetType() == mysql.TypeYear { - res, err = types.ParseTimeFromYear(b.ctx.GetSessionVars().StmtCtx, val) + res, err = types.ParseTimeFromYear(val) } else { res, err = types.ParseTimeFromNum(b.ctx.GetSessionVars().StmtCtx.TypeCtx(), val, b.tp.GetType(), b.tp.GetDecimal()) } diff --git a/pkg/expression/builtin_cast_vec.go b/pkg/expression/builtin_cast_vec.go index 2af8f84bd97a3..d4dcd97b07e23 100644 --- a/pkg/expression/builtin_cast_vec.go +++ b/pkg/expression/builtin_cast_vec.go @@ -385,7 +385,7 @@ func (b *builtinCastIntAsTimeSig) vecEvalTime(input *chunk.Chunk, result *chunk. } if b.args[0].GetType().GetType() == mysql.TypeYear { - tm, err = types.ParseTimeFromYear(stmt, i64s[i]) + tm, err = types.ParseTimeFromYear(i64s[i]) } else { tm, err = types.ParseTimeFromNum(stmt.TypeCtx(), i64s[i], b.tp.GetType(), fsp) } diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index 0eecdf1b52cc2..133a54182f9b1 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -1443,7 +1443,7 @@ func tryToConvertConstantInt(ctx sessionctx.Context, targetFieldType *types.Fiel } sc := ctx.GetSessionVars().StmtCtx - dt, err = dt.ConvertTo(sc, targetFieldType) + dt, err = dt.ConvertTo(sc.TypeCtx(), targetFieldType) if err != nil { if terror.ErrorEqual(err, types.ErrOverflow) { return &Constant{ @@ -1482,7 +1482,7 @@ func RefineComparedConstant(ctx sessionctx.Context, targetFieldType types.FieldT targetFieldType = *types.NewFieldType(mysql.TypeLonglong) } var intDatum types.Datum - intDatum, err = dt.ConvertTo(sc, &targetFieldType) + intDatum, err = dt.ConvertTo(sc.TypeCtx(), &targetFieldType) if err != nil { if terror.ErrorEqual(err, types.ErrOverflow) { return &Constant{ @@ -1494,7 +1494,7 @@ func RefineComparedConstant(ctx sessionctx.Context, targetFieldType types.FieldT } return con, false } - c, err := intDatum.Compare(sc, &con.Value, collate.GetBinaryCollator()) + c, err := intDatum.Compare(sc.TypeCtx(), &con.Value, collate.GetBinaryCollator()) if err != nil { return con, false } @@ -1539,7 +1539,7 @@ func RefineComparedConstant(ctx sessionctx.Context, targetFieldType types.FieldT // 3. Suppose the value of `con` is 2, when `targetFieldType.GetType()` is `TypeYear`, the value of `doubleDatum` // will be 2.0 and the value of `intDatum` will be 2002 in this case. var doubleDatum types.Datum - doubleDatum, err = dt.ConvertTo(sc, types.NewFieldType(mysql.TypeDouble)) + doubleDatum, err = dt.ConvertTo(sc.TypeCtx(), types.NewFieldType(mysql.TypeDouble)) if err != nil { return con, false } @@ -1737,7 +1737,7 @@ func (c *compareFunctionClass) refineNumericConstantCmpDatetime(ctx sessionctx.C sc := ctx.GetSessionVars().StmtCtx var datetimeDatum types.Datum targetFieldType := types.NewFieldType(mysql.TypeDatetime) - datetimeDatum, err = dt.ConvertTo(sc, targetFieldType) + datetimeDatum, err = dt.ConvertTo(sc.TypeCtx(), targetFieldType) if err != nil || datetimeDatum.IsNull() { return args } diff --git a/pkg/expression/builtin_other_test.go b/pkg/expression/builtin_other_test.go index c0d364edc9cfc..968cf9ee5aa91 100644 --- a/pkg/expression/builtin_other_test.go +++ b/pkg/expression/builtin_other_test.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/collate" @@ -67,9 +66,8 @@ func TestBitCount(t *testing.T) { require.Nil(t, test.count) continue } - sc := stmtctx.NewStmtCtxWithTimeZone(stmtCtx.TimeZone()) - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) - res, err := count.ToInt64(sc) + ctx := types.DefaultStmtNoWarningContext.WithFlags(types.DefaultStmtFlags.WithIgnoreTruncateErr(true)) + res, err := count.ToInt64(ctx) require.NoError(t, err) require.Equal(t, test.count, res) } @@ -195,7 +193,7 @@ func TestValues(t *testing.T) { ret, err = evalBuiltinFunc(sig, chunk.Row{}) require.NoError(t, err) - cmp, err := ret.Compare(nil, &currInsertValues[1], collate.GetBinaryCollator()) + cmp, err := ret.Compare(types.DefaultStmtNoWarningContext, &currInsertValues[1], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) } diff --git a/pkg/expression/builtin_time_test.go b/pkg/expression/builtin_time_test.go index cb85b8ab63146..f605647c711eb 100644 --- a/pkg/expression/builtin_time_test.go +++ b/pkg/expression/builtin_time_test.go @@ -1192,7 +1192,7 @@ func convertToTimeWithFsp(sc *stmtctx.StatementContext, arg types.Datum, tp byte f := types.NewFieldType(tp) f.SetDecimal(fsp) - d, err = arg.ConvertTo(sc, f) + d, err = arg.ConvertTo(sc.TypeCtx(), f) if err != nil { d.SetNull() return d, err diff --git a/pkg/expression/column.go b/pkg/expression/column.go index e7e8af727a6c1..bfd1d051ea577 100644 --- a/pkg/expression/column.go +++ b/pkg/expression/column.go @@ -100,7 +100,7 @@ func (col *CorrelatedColumn) EvalInt(ctx sessionctx.Context, row chunk.Row) (int return 0, true, nil } if col.GetType().Hybrid() { - res, err := col.Data.ToInt64(ctx.GetSessionVars().StmtCtx) + res, err := col.Data.ToInt64(ctx.GetSessionVars().StmtCtx.TypeCtx()) return res, err != nil, err } return col.Data.GetInt64(), false, nil @@ -425,7 +425,7 @@ func (col *Column) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, val, err := val.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) return int64(val), err != nil, err } - res, err := val.ToInt64(ctx.GetSessionVars().StmtCtx) + res, err := val.ToInt64(ctx.GetSessionVars().StmtCtx.TypeCtx()) return res, err != nil, err } if row.IsNull(col.Index) { @@ -703,7 +703,7 @@ func (col *Column) SupportReverseEval() bool { // ReverseEval evaluates the only one column value with given function result. func (col *Column) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rType types.RoundingType) (val types.Datum, err error) { - return types.ChangeReverseResultByUpperLowerBound(sc, col.RetType, res, rType) + return types.ChangeReverseResultByUpperLowerBound(sc.TypeCtx(), col.RetType, res, rType) } // Coercibility returns the coercibility value which is used to check collations. diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index 74f2504f8beaf..91703c3c29595 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -250,7 +250,7 @@ func (c *Constant) Eval(row chunk.Row) (types.Datum, error) { sf, sfOk := c.DeferredExpr.(*ScalarFunction) if sfOk { if dt.Kind() != types.KindMysqlDecimal { - val, err := dt.ConvertTo(sf.GetCtx().GetSessionVars().StmtCtx, c.RetType) + val, err := dt.ConvertTo(sf.GetCtx().GetSessionVars().StmtCtx.TypeCtx(), c.RetType) if err != nil { return dt, err } @@ -281,7 +281,7 @@ func (c *Constant) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, val, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) return int64(val), err != nil, err } else if c.GetType().Hybrid() || dt.Kind() == types.KindString { - res, err := dt.ToInt64(ctx.GetSessionVars().StmtCtx) + res, err := dt.ToInt64(ctx.GetSessionVars().StmtCtx.TypeCtx()) return res, false, err } else if dt.Kind() == types.KindMysqlBit { uintVal, err := dt.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx.TypeCtx()) @@ -412,7 +412,7 @@ func (c *Constant) Equal(ctx sessionctx.Context, b Expression) bool { if err1 != nil || err2 != nil { return false } - con, err := c.Value.Compare(ctx.GetSessionVars().StmtCtx, &y.Value, collate.GetBinaryCollator()) + con, err := c.Value.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &y.Value, collate.GetBinaryCollator()) if err != nil || con != 0 { return false } diff --git a/pkg/expression/constant_propagation.go b/pkg/expression/constant_propagation.go index 5bb233a3127f2..5b19da65c55e9 100644 --- a/pkg/expression/constant_propagation.go +++ b/pkg/expression/constant_propagation.go @@ -62,7 +62,7 @@ func (s *basePropConstSolver) tryToUpdateEQList(col *Column, con *Constant) (boo id := s.getColID(col) oldCon := s.eqList[id] if oldCon != nil { - res, err := oldCon.Value.Compare(s.ctx.GetSessionVars().StmtCtx, &con.Value, collate.GetCollator(col.GetType().GetCollate())) + res, err := oldCon.Value.Compare(s.ctx.GetSessionVars().StmtCtx.TypeCtx(), &con.Value, collate.GetCollator(col.GetType().GetCollate())) return false, res != 0 || err != nil } s.eqList[id] = con diff --git a/pkg/expression/distsql_builtin_test.go b/pkg/expression/distsql_builtin_test.go index 34d9ef2ca49bd..b12fb40f8fc08 100644 --- a/pkg/expression/distsql_builtin_test.go +++ b/pkg/expression/distsql_builtin_test.go @@ -785,7 +785,7 @@ func TestEval(t *testing.T) { result, err := expr.Eval(row) require.NoError(t, err) require.Equal(t, tt.result.Kind(), result.Kind()) - cmp, err := result.Compare(sc, &tt.result, collate.GetCollator(fieldTps[0].GetCollate())) + cmp, err := result.Compare(sc.TypeCtx(), &tt.result, collate.GetCollator(fieldTps[0].GetCollate())) require.NoError(t, err) require.Equal(t, 0, cmp) } diff --git a/pkg/expression/evaluator_test.go b/pkg/expression/evaluator_test.go index cf1295d71ad7e..24ecee4186b92 100644 --- a/pkg/expression/evaluator_test.go +++ b/pkg/expression/evaluator_test.go @@ -568,7 +568,7 @@ func TestUnaryOp(t *testing.T) { require.NoError(t, err) expect := types.NewDatum(tt.result) - ret, err := result.Compare(ctx.GetSessionVars().StmtCtx, &expect, collate.GetBinaryCollator()) + ret, err := result.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &expect, collate.GetBinaryCollator()) require.NoError(t, err) require.Equalf(t, 0, ret, "%v %s", tt.arg, tt.op) } diff --git a/pkg/expression/helper.go b/pkg/expression/helper.go index 8e07ae72e4ca1..0643db7f850e9 100644 --- a/pkg/expression/helper.go +++ b/pkg/expression/helper.go @@ -137,7 +137,7 @@ func GetTimeValue(ctx sessionctx.Context, v interface{}, tp byte, fsp int, expli return d, err } ft := types.NewFieldType(mysql.TypeLonglong) - xval, err := v.ConvertTo(ctx.GetSessionVars().StmtCtx, ft) + xval, err := v.ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), ft) if err != nil { return d, err } diff --git a/pkg/planner/cardinality/pseudo.go b/pkg/planner/cardinality/pseudo.go index bd0d62a4b7d7b..7a8dba62a17fc 100644 --- a/pkg/planner/cardinality/pseudo.go +++ b/pkg/planner/cardinality/pseudo.go @@ -217,7 +217,7 @@ func getPseudoRowCountByColumnRanges(sc *stmtctx.StatementContext, tableRowCount } else if ran.HighVal[colIdx].Kind() == types.KindMaxValue { rowCount += tableRowCount / pseudoLessRate } else { - compare, err := ran.LowVal[colIdx].Compare(sc, &ran.HighVal[colIdx], ran.Collators[colIdx]) + compare, err := ran.LowVal[colIdx].Compare(sc.TypeCtx(), &ran.HighVal[colIdx], ran.Collators[colIdx]) if err != nil { return 0, errors.Trace(err) } diff --git a/pkg/planner/cardinality/row_count_column.go b/pkg/planner/cardinality/row_count_column.go index f9f075dc818f6..ee8a6220655de 100644 --- a/pkg/planner/cardinality/row_count_column.go +++ b/pkg/planner/cardinality/row_count_column.go @@ -187,7 +187,7 @@ func GetColumnRowCount(sctx sessionctx.Context, c *statistics.Column, ranges []* if lowVal.Kind() == types.KindString { lowVal.SetBytes(collate.GetCollator(lowVal.Collation()).Key(lowVal.GetString())) } - cmp, err := lowVal.Compare(sc, &highVal, collate.GetBinaryCollator()) + cmp, err := lowVal.Compare(sc.TypeCtx(), &highVal, collate.GetBinaryCollator()) if err != nil { return 0, errors.Trace(err) } diff --git a/pkg/planner/cardinality/row_count_index.go b/pkg/planner/cardinality/row_count_index.go index 06ff7c9627cf3..eef562fd59a40 100644 --- a/pkg/planner/cardinality/row_count_index.go +++ b/pkg/planner/cardinality/row_count_index.go @@ -515,7 +515,7 @@ func betweenRowCountOnIndex(sctx sessionctx.Context, idx *statistics.Index, l, r func getOrdinalOfRangeCond(sc *stmtctx.StatementContext, ran *ranger.Range) int { for i := range ran.LowVal { a, b := ran.LowVal[i], ran.HighVal[i] - cmp, err := a.Compare(sc, &b, ran.Collators[0]) + cmp, err := a.Compare(sc.TypeCtx(), &b, ran.Collators[0]) if err != nil { return 0 } diff --git a/pkg/planner/core/handle_cols.go b/pkg/planner/core/handle_cols.go index c745e3d1fb4c3..13c07443501f9 100644 --- a/pkg/planner/core/handle_cols.go +++ b/pkg/planner/core/handle_cols.go @@ -174,7 +174,7 @@ func (cb *CommonHandleCols) Compare(a, b []types.Datum, ctors []collate.Collator for i, col := range cb.columns { aDatum := &a[col.Index] bDatum := &b[col.Index] - cmp, err := aDatum.Compare(cb.sc, bDatum, ctors[i]) + cmp, err := aDatum.Compare(cb.sc.TypeCtx(), bDatum, ctors[i]) if err != nil { return 0, err } @@ -288,7 +288,7 @@ func (*IntHandleCols) NumCols() int { func (ib *IntHandleCols) Compare(a, b []types.Datum, ctors []collate.Collator) (int, error) { aVal := &a[ib.col.Index] bVal := &b[ib.col.Index] - return aVal.Compare(nil, bVal, ctors[ib.col.Index]) + return aVal.Compare(types.DefaultStmtNoWarningContext, bVal, ctors[ib.col.Index]) } // GetFieldsTypes implements the kv.HandleCols interface. diff --git a/pkg/planner/core/memtable_predicate_extractor.go b/pkg/planner/core/memtable_predicate_extractor.go index fb33a510e7bdc..6ea343681928e 100644 --- a/pkg/planner/core/memtable_predicate_extractor.go +++ b/pkg/planner/core/memtable_predicate_extractor.go @@ -446,7 +446,7 @@ func (helper extractHelper) extractTimeRange( if colName == extractColName { timeType := types.NewFieldType(mysql.TypeDatetime) timeType.SetDecimal(6) - timeDatum, err := datums[0].ConvertTo(ctx.GetSessionVars().StmtCtx, timeType) + timeDatum, err := datums[0].ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), timeType) if err != nil || timeDatum.Kind() == types.KindNull { remained = append(remained, expr) continue diff --git a/pkg/planner/core/plan_cache.go b/pkg/planner/core/plan_cache.go index 9cad0384c919b..6b406788035af 100644 --- a/pkg/planner/core/plan_cache.go +++ b/pkg/planner/core/plan_cache.go @@ -494,7 +494,7 @@ func rebuildRange(p Plan) error { if err != nil { return err } - iv, err := dVal.ToInt64(sc) + iv, err := dVal.ToInt64(sc.TypeCtx()) if err != nil { return err } @@ -560,7 +560,7 @@ func rebuildRange(p Plan) error { if err != nil { return err } - iv, err := dVal.ToInt64(sc) + iv, err := dVal.ToInt64(sc.TypeCtx()) if err != nil { return err } @@ -619,12 +619,12 @@ func convertConstant2Datum(sc *stmtctx.StatementContext, con *expression.Constan if err != nil { return nil, err } - dVal, err := val.ConvertTo(sc, target) + dVal, err := val.ConvertTo(sc.TypeCtx(), target) if err != nil { return nil, err } // The converted result must be same as original datum. - cmp, err := dVal.Compare(sc, &val, collate.GetCollator(target.GetCollate())) + cmp, err := dVal.Compare(sc.TypeCtx(), &val, collate.GetCollator(target.GetCollate())) if err != nil || cmp != 0 { return nil, errors.New("Convert constant to datum is failed, because the constant has changed after the covert") } diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index e256e90d8fe66..b216dbdc5eac0 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -4676,7 +4676,7 @@ func (b *PlanBuilder) convertValue(valueItem ast.ExprNode, mockTablePlan Logical if err != nil { return d, err } - d, err = value.ConvertTo(b.ctx.GetSessionVars().StmtCtx, &col.FieldType) + d, err = value.ConvertTo(b.ctx.GetSessionVars().StmtCtx.TypeCtx(), &col.FieldType) if err != nil { if !types.ErrTruncated.Equal(err) && !types.ErrTruncatedWrongVal.Equal(err) && !types.ErrBadNumber.Equal(err) { return d, err @@ -5571,10 +5571,8 @@ func calcTSForPlanReplayer(sctx sessionctx.Context, tsExpr ast.ExprNode) uint64 tpLonglong.SetFlag(mysql.UnsignedFlag) // We need a strict check, which means no truncate or any other warnings/errors, or it will wrongly try to parse // a date/time string into a TSO. - // To achieve this, we need to set fields like StatementContext.IgnoreTruncate to false, and maybe it's better - // not to modify and reuse the original StatementContext, so we use a temporary one here. - tmpStmtCtx := stmtctx.NewStmtCtxWithTimeZone(sctx.GetSessionVars().Location()) - tso, err := tsVal.ConvertTo(tmpStmtCtx, tpLonglong) + // To achieve this, we create a new type context without re-using the one in statement context. + tso, err := tsVal.ConvertTo(types.DefaultStmtNoWarningContext.WithLocation(sctx.GetSessionVars().Location()), tpLonglong) if err == nil { return tso.GetUint64() } @@ -5583,7 +5581,7 @@ func calcTSForPlanReplayer(sctx sessionctx.Context, tsExpr ast.ExprNode) uint64 // this part is similar to CalculateAsOfTsExpr tpDateTime := types.NewFieldType(mysql.TypeDatetime) tpDateTime.SetDecimal(6) - timestamp, err := tsVal.ConvertTo(sctx.GetSessionVars().StmtCtx, tpDateTime) + timestamp, err := tsVal.ConvertTo(sctx.GetSessionVars().StmtCtx.TypeCtx(), tpDateTime) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(err) return 0 diff --git a/pkg/planner/core/point_get_plan.go b/pkg/planner/core/point_get_plan.go index 63403cfe5256f..68ab03f220ba6 100644 --- a/pkg/planner/core/point_get_plan.go +++ b/pkg/planner/core/point_get_plan.go @@ -1444,7 +1444,7 @@ func getNameValuePairs(ctx sessionctx.Context, tbl *model.TableInfo, tblName mod if !checkCanConvertInPointGet(col, d) { return nil, false } - dVal, err := d.ConvertTo(stmtCtx, &col.FieldType) + dVal, err := d.ConvertTo(stmtCtx.TypeCtx(), &col.FieldType) if err != nil { if terror.ErrorEqual(types.ErrOverflow, err) { return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, colFieldType: &col.FieldType, value: d, con: con}), true @@ -1455,7 +1455,7 @@ func getNameValuePairs(ctx sessionctx.Context, tbl *model.TableInfo, tblName mod } } // The converted result must be same as original datum. - cmp, err := dVal.Compare(stmtCtx, &d, collate.GetCollator(col.GetCollate())) + cmp, err := dVal.Compare(stmtCtx.TypeCtx(), &d, collate.GetCollator(col.GetCollate())) if err != nil || cmp != 0 { return nil, false } @@ -1468,12 +1468,12 @@ func getPointGetValue(stmtCtx *stmtctx.StatementContext, col *model.ColumnInfo, if !checkCanConvertInPointGet(col, *d) { return nil } - dVal, err := d.ConvertTo(stmtCtx, &col.FieldType) + dVal, err := d.ConvertTo(stmtCtx.TypeCtx(), &col.FieldType) if err != nil { return nil } // The converted result must be same as original datum. - cmp, err := dVal.Compare(stmtCtx, d, collate.GetCollator(col.GetCollate())) + cmp, err := dVal.Compare(stmtCtx.TypeCtx(), d, collate.GetCollator(col.GetCollate())) if err != nil || cmp != 0 { return nil } diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index 661328691cde7..f3bfe3eaf3007 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -1081,7 +1081,7 @@ func minCmp(ctx sessionctx.Context, lowVal []types.Datum, columnsPruner *rangeCo return true } // Add Null as point here? - cmp, err := con.Value.Compare(ctx.GetSessionVars().StmtCtx, &lowVal[j], comparer[j]) + cmp, err := con.Value.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &lowVal[j], comparer[j]) if err != nil { *gotError = true } @@ -1160,7 +1160,7 @@ func maxCmp(ctx sessionctx.Context, hiVal []types.Datum, columnsPruner *rangeCol return false } // Add Null as point here? - cmp, err := con.Value.Compare(ctx.GetSessionVars().StmtCtx, &hiVal[j], comparer[j]) + cmp, err := con.Value.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &hiVal[j], comparer[j]) if err != nil { *gotError = true // error pushed, we will still use the cmp value @@ -1388,7 +1388,7 @@ func partitionRangeForInExpr(sctx sessionctx.Context, args []expression.Expressi partFnConst := replaceColumnWithConst(pruner.partFn, constExpr) val, _, err = partFnConst.EvalInt(sctx, chunk.Row{}) } else { - val, err = constExpr.Value.ToInt64(sctx.GetSessionVars().StmtCtx) + val, err = constExpr.Value.ToInt64(sctx.GetSessionVars().StmtCtx.TypeCtx()) } if err != nil { return pruner.fullRange() diff --git a/pkg/server/handler/tikv_handler.go b/pkg/server/handler/tikv_handler.go index 5dd0c8af3f4c4..b5e72c1f5c508 100644 --- a/pkg/server/handler/tikv_handler.go +++ b/pkg/server/handler/tikv_handler.go @@ -160,7 +160,7 @@ func (*TikvHandlerTool) formValue2DatumRow(sc *stmtctx.StatementContext, values data[i].SetNull() case 1: bDatum := types.NewStringDatum(vals[0]) - cDatum, err := bDatum.ConvertTo(sc, &col.FieldType) + cDatum, err := bDatum.ConvertTo(sc.TypeCtx(), &col.FieldType) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/session/nontransactional.go b/pkg/session/nontransactional.go index b9f150529ce39..197643c48ebb3 100644 --- a/pkg/session/nontransactional.go +++ b/pkg/session/nontransactional.go @@ -509,7 +509,7 @@ func buildShardJobs(ctx context.Context, stmt *ast.NonTransactionalDMLStmt, se S } newEnd := row.GetDatum(0, &rs.Fields()[0].Column.FieldType) if currentSize >= batchSize { - cmp, err := newEnd.Compare(se.GetSessionVars().StmtCtx, ¤tEnd, collate.GetCollator(shardColumnCollate)) + cmp, err := newEnd.Compare(se.GetSessionVars().StmtCtx.TypeCtx(), ¤tEnd, collate.GetCollator(shardColumnCollate)) if err != nil { return nil, err } diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel index 1178d300448ea..f768b8c2eb7a1 100644 --- a/pkg/sessionctx/stmtctx/BUILD.bazel +++ b/pkg/sessionctx/stmtctx/BUILD.bazel @@ -12,7 +12,7 @@ go_library( "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/parser/terror", - "//pkg/types/context", + "//pkg/types", "//pkg/util/disk", "//pkg/util/execdetails", "//pkg/util/intest", @@ -47,7 +47,6 @@ go_test( "//pkg/testkit", "//pkg/testkit/testsetup", "//pkg/types", - "//pkg/types/context", "//pkg/util/execdetails", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 8168ef94198bc..2216d05ef3f99 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -34,7 +34,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - typectx "github.com/pingcap/tidb/pkg/types/context" + "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/disk" "github.com/pingcap/tidb/pkg/util/execdetails" "github.com/pingcap/tidb/pkg/util/intest" @@ -156,7 +156,7 @@ type StatementContext struct { _ constructor.Constructor `ctor:"NewStmtCtx,NewStmtCtxWithTimeZone,Reset"` // typeCtx is used to indicate how to make the type conversation. - typeCtx typectx.Context + typeCtx types.Context // Set the following variables before execution StmtHints @@ -425,7 +425,7 @@ type StatementContext struct { // NewStmtCtx creates a new statement context func NewStmtCtx() *StatementContext { sc := &StatementContext{} - sc.typeCtx = typectx.NewContext(typectx.DefaultStmtFlags, time.UTC, sc.AppendWarning) + sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning) return sc } @@ -433,14 +433,14 @@ func NewStmtCtx() *StatementContext { func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { intest.Assert(tz) sc := &StatementContext{} - sc.typeCtx = typectx.NewContext(typectx.DefaultStmtFlags, tz, sc.AppendWarning) + sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc.AppendWarning) return sc } // Reset resets a statement context func (sc *StatementContext) Reset() { *sc = StatementContext{ - typeCtx: typectx.NewContext(typectx.DefaultStmtFlags, time.UTC, sc.AppendWarning), + typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning), } } @@ -456,17 +456,17 @@ func (sc *StatementContext) SetTimeZone(tz *time.Location) { } // TypeCtx returns the type context -func (sc *StatementContext) TypeCtx() typectx.Context { +func (sc *StatementContext) TypeCtx() types.Context { return sc.typeCtx } // TypeFlags returns the type flags -func (sc *StatementContext) TypeFlags() typectx.Flags { +func (sc *StatementContext) TypeFlags() types.Flags { return sc.typeCtx.Flags() } // SetTypeFlags sets the type flags -func (sc *StatementContext) SetTypeFlags(flags typectx.Flags) { +func (sc *StatementContext) SetTypeFlags(flags types.Flags) { sc.typeCtx = sc.typeCtx.WithFlags(flags) } @@ -1209,7 +1209,7 @@ func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location) sc.OverflowAsWarning = (flags & model.FlagOverflowAsWarning) > 0 sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0 sc.SetTimeZone(tz) - sc.SetTypeFlags(typectx.DefaultStmtFlags. + sc.SetTypeFlags(types.DefaultStmtFlags. WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0). WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0). WithIgnoreZeroInDate((flags & model.FlagIgnoreZeroInDate) > 0). @@ -1356,12 +1356,12 @@ func (sc *StatementContext) RecordedStatsLoadStatusCnt() (cnt int) { // If the statement context is nil, it'll return a newly created default type context. // **don't** use this function if you can make sure the `sc` is not nil. We should limit the usage of this function as // little as possible. -func (sc *StatementContext) TypeCtxOrDefault() typectx.Context { +func (sc *StatementContext) TypeCtxOrDefault() types.Context { if sc != nil { return sc.typeCtx } - return typectx.DefaultStmtNoWarningContext + return types.DefaultStmtNoWarningContext } // UsedStatsInfoForTable records stats that are used during query and their information. diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index 80686718aa784..60b5bdad3ff84 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/types" - typectx "github.com/pingcap/tidb/pkg/types/context" "github.com/pingcap/tidb/pkg/util/execdetails" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" @@ -349,12 +348,12 @@ func TestSetStmtCtxTypeFlags(t *testing.T) { sc := stmtctx.NewStmtCtx() require.Equal(t, types.DefaultStmtFlags, sc.TypeFlags()) - sc.SetTypeFlags(typectx.FlagAllowNegativeToUnsigned | typectx.FlagSkipASCIICheck) - require.Equal(t, typectx.FlagAllowNegativeToUnsigned|typectx.FlagSkipASCIICheck, sc.TypeFlags()) + sc.SetTypeFlags(types.FlagAllowNegativeToUnsigned | types.FlagSkipASCIICheck) + require.Equal(t, types.FlagAllowNegativeToUnsigned|types.FlagSkipASCIICheck, sc.TypeFlags()) require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) - sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagTruncateAsWarning) - require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagTruncateAsWarning, sc.TypeFlags()) + sc.SetTypeFlags(types.FlagSkipASCIICheck | types.FlagSkipUTF8Check | types.FlagTruncateAsWarning) + require.Equal(t, types.FlagSkipASCIICheck|types.FlagSkipUTF8Check|types.FlagTruncateAsWarning, sc.TypeFlags()) require.Equal(t, sc.TypeFlags(), sc.TypeFlags()) } @@ -364,13 +363,13 @@ func TestResetStmtCtx(t *testing.T) { tz := time.FixedZone("UTC+1", 2*60*60) sc.SetTimeZone(tz) - sc.SetTypeFlags(typectx.FlagAllowNegativeToUnsigned | typectx.FlagSkipASCIICheck) + sc.SetTypeFlags(types.FlagAllowNegativeToUnsigned | types.FlagSkipASCIICheck) sc.AppendWarning(errors.New("err1")) sc.InRestrictedSQL = true sc.StmtType = "Insert" require.Same(t, tz, sc.TimeZone()) - require.Equal(t, typectx.FlagAllowNegativeToUnsigned|typectx.FlagSkipASCIICheck, sc.TypeFlags()) + require.Equal(t, types.FlagAllowNegativeToUnsigned|types.FlagSkipASCIICheck, sc.TypeFlags()) require.Equal(t, 1, len(sc.GetWarnings())) sc.Reset() diff --git a/pkg/sessiontxn/staleread/util.go b/pkg/sessiontxn/staleread/util.go index fde24b7b7c8ed..0f578694b5b80 100644 --- a/pkg/sessiontxn/staleread/util.go +++ b/pkg/sessiontxn/staleread/util.go @@ -52,7 +52,7 @@ func CalculateAsOfTsExpr(ctx context.Context, sctx sessionctx.Context, tsExpr as toTypeTimestamp := types.NewFieldType(mysql.TypeTimestamp) // We need at least the millionsecond here, so set fsp to 3. toTypeTimestamp.SetDecimal(3) - tsTimestamp, err := tsVal.ConvertTo(sctx.GetSessionVars().StmtCtx, toTypeTimestamp) + tsTimestamp, err := tsVal.ConvertTo(sctx.GetSessionVars().StmtCtx.TypeCtx(), toTypeTimestamp) if err != nil { return 0, err } diff --git a/pkg/statistics/builder.go b/pkg/statistics/builder.go index 363c023a579a9..b7dd64a20a4ad 100644 --- a/pkg/statistics/builder.go +++ b/pkg/statistics/builder.go @@ -71,7 +71,7 @@ func (b *SortedBuilder) Iterate(data types.Datum) error { b.hist.NDV = 1 return nil } - cmp, err := b.hist.GetUpper(int(b.bucketIdx)).Compare(b.sc, &data, collate.GetBinaryCollator()) + cmp, err := b.hist.GetUpper(int(b.bucketIdx)).Compare(b.sc.TypeCtx(), &data, collate.GetBinaryCollator()) if err != nil { return errors.Trace(err) } @@ -178,7 +178,7 @@ func buildHist(sc *stmtctx.StatementContext, hg *Histogram, samples []*SampleIte memTracker.BufferedConsume(&bufferedMemSize, deltaSize) memTracker.BufferedRelease(&bufferedReleaseSize, deltaSize) } - cmp, err := upper.Compare(sc, &samples[i].Value, collate.GetBinaryCollator()) + cmp, err := upper.Compare(sc.TypeCtx(), &samples[i].Value, collate.GetBinaryCollator()) if err != nil { return 0, errors.Trace(err) } diff --git a/pkg/statistics/handle/bootstrap.go b/pkg/statistics/handle/bootstrap.go index ba5895c8d5091..83dd83d1a209a 100644 --- a/pkg/statistics/handle/bootstrap.go +++ b/pkg/statistics/handle/bootstrap.go @@ -407,14 +407,14 @@ func (*Handle) initStatsBuckets4Chunk(cache util.StatsCache, iter *chunk.Iterato sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) sc.SetTypeFlags(sc.TypeFlags().WithIgnoreInvalidDateErr(true).WithIgnoreZeroInDate(true)) var err error - lower, err = d.ConvertTo(sc, &column.Info.FieldType) + lower, err = d.ConvertTo(sc.TypeCtx(), &column.Info.FieldType) if err != nil { logutil.BgLogger().Debug("decode bucket lower bound failed", zap.Error(err)) delete(table.Columns, histID) continue } d = types.NewBytesDatum(row.GetBytes(6)) - upper, err = d.ConvertTo(sc, &column.Info.FieldType) + upper, err = d.ConvertTo(sc.TypeCtx(), &column.Info.FieldType) if err != nil { logutil.BgLogger().Debug("decode bucket upper bound failed", zap.Error(err)) delete(table.Columns, histID) diff --git a/pkg/statistics/handle/storage/read.go b/pkg/statistics/handle/storage/read.go index 352171f7a0dcf..e8c0894148b2c 100644 --- a/pkg/statistics/handle/storage/read.go +++ b/pkg/statistics/handle/storage/read.go @@ -85,12 +85,12 @@ func HistogramFromStorage(sctx sessionctx.Context, tableID int64, colID int64, t if tp.EvalType() == types.ETString && tp.GetType() != mysql.TypeEnum && tp.GetType() != mysql.TypeSet { tp = types.NewFieldType(mysql.TypeBlob) } - lowerBound, err = d.ConvertTo(sc, tp) + lowerBound, err = d.ConvertTo(sc.TypeCtx(), tp) if err != nil { return nil, errors.Trace(err) } d = rows[i].GetDatum(3, &fields[3].Column.FieldType) - upperBound, err = d.ConvertTo(sc, tp) + upperBound, err = d.ConvertTo(sc.TypeCtx(), tp) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/statistics/handle/storage/save.go b/pkg/statistics/handle/storage/save.go index c69a6633ee3f9..179968af0beed 100644 --- a/pkg/statistics/handle/storage/save.go +++ b/pkg/statistics/handle/storage/save.go @@ -90,7 +90,7 @@ func saveBucketsToStorage(sctx sessionctx.Context, tableID int64, isIndex int, h count -= hg.Buckets[j-1].Count } var upperBound types.Datum - upperBound, err = hg.GetUpper(j).ConvertTo(sc, types.NewFieldType(mysql.TypeBlob)) + upperBound, err = hg.GetUpper(j).ConvertTo(sc.TypeCtx(), types.NewFieldType(mysql.TypeBlob)) if err != nil { return } @@ -98,7 +98,7 @@ func saveBucketsToStorage(sctx sessionctx.Context, tableID int64, isIndex int, h lastAnalyzePos = upperBound.GetBytes() } var lowerBound types.Datum - lowerBound, err = hg.GetLower(j).ConvertTo(sc, types.NewFieldType(mysql.TypeBlob)) + lowerBound, err = hg.GetLower(j).ConvertTo(sc.TypeCtx(), types.NewFieldType(mysql.TypeBlob)) if err != nil { return } diff --git a/pkg/statistics/handle/storage/stats_read_writer.go b/pkg/statistics/handle/storage/stats_read_writer.go index edc6043c46578..b187b7b1e557b 100644 --- a/pkg/statistics/handle/storage/stats_read_writer.go +++ b/pkg/statistics/handle/storage/stats_read_writer.go @@ -86,7 +86,7 @@ func (s *statsReadWriter) InsertColStats2KV(physicalID int64, colInfos []*model. count := req.GetRow(0).GetInt64(0) for _, colInfo := range colInfos { value := types.NewDatum(colInfo.GetOriginDefaultValue()) - value, err = value.ConvertTo(sctx.GetSessionVars().StmtCtx, &colInfo.FieldType) + value, err = value.ConvertTo(sctx.GetSessionVars().StmtCtx.TypeCtx(), &colInfo.FieldType) if err != nil { return err } @@ -100,7 +100,7 @@ func (s *statsReadWriter) InsertColStats2KV(physicalID int64, colInfos []*model. if _, err := util.Exec(sctx, "insert into mysql.stats_histograms (version, table_id, is_index, hist_id, distinct_count, tot_col_size) values (%?, %?, 0, %?, 1, %?)", startTS, physicalID, colInfo.ID, int64(len(value.GetBytes()))*count); err != nil { return err } - value, err = value.ConvertTo(sctx.GetSessionVars().StmtCtx, types.NewFieldType(mysql.TypeBlob)) + value, err = value.ConvertTo(sctx.GetSessionVars().StmtCtx.TypeCtx(), types.NewFieldType(mysql.TypeBlob)) if err != nil { return err } diff --git a/pkg/statistics/histogram.go b/pkg/statistics/histogram.go index 932c0bd986255..a94f1512720ab 100644 --- a/pkg/statistics/histogram.go +++ b/pkg/statistics/histogram.go @@ -202,7 +202,7 @@ func (hg *Histogram) ConvertTo(sc *stmtctx.StatementContext, tp *types.FieldType iter := chunk.NewIterator4Chunk(hg.Bounds) for row := iter.Begin(); row != iter.End(); row = iter.Next() { d := row.GetDatum(0, hg.Tp) - d, err := d.ConvertTo(sc, tp) + d, err := d.ConvertTo(sc.TypeCtx(), tp) if err != nil { return nil, errors.Trace(err) } @@ -832,7 +832,7 @@ func MergeHistograms(sc *stmtctx.StatementContext, lh *Histogram, rh *Histogram, } lh.NDV += rh.NDV lLen := lh.Len() - cmp, err := lh.GetUpper(lLen-1).Compare(sc, rh.GetLower(0), collate.GetBinaryCollator()) + cmp, err := lh.GetUpper(lLen-1).Compare(sc.TypeCtx(), rh.GetLower(0), collate.GetBinaryCollator()) if err != nil { return nil, errors.Trace(err) } @@ -1254,7 +1254,7 @@ func mergeBucketNDV(sc *stmtctx.StatementContext, left *bucket4Merging, right *b res.NDV = left.NDV return &res, nil } - upperCompare, err := right.upper.Compare(sc, left.upper, collate.GetBinaryCollator()) + upperCompare, err := right.upper.Compare(sc.TypeCtx(), left.upper, collate.GetBinaryCollator()) if err != nil { return nil, err } @@ -1268,7 +1268,7 @@ func mergeBucketNDV(sc *stmtctx.StatementContext, left *bucket4Merging, right *b // ___left__| // They have the same upper. if upperCompare == 0 { - lowerCompare, err := right.lower.Compare(sc, left.lower, collate.GetBinaryCollator()) + lowerCompare, err := right.lower.Compare(sc.TypeCtx(), left.lower, collate.GetBinaryCollator()) if err != nil { return nil, err } @@ -1299,7 +1299,7 @@ func mergeBucketNDV(sc *stmtctx.StatementContext, left *bucket4Merging, right *b // ____right___| // ____left__| // right.upper > left.upper - lowerCompareUpper, err := right.lower.Compare(sc, left.upper, collate.GetBinaryCollator()) + lowerCompareUpper, err := right.lower.Compare(sc.TypeCtx(), left.upper, collate.GetBinaryCollator()) if err != nil { return nil, err } @@ -1316,7 +1316,7 @@ func mergeBucketNDV(sc *stmtctx.StatementContext, left *bucket4Merging, right *b return &res, nil } upperRatio := calcFraction4Datums(right.lower, right.upper, left.upper) - lowerCompare, err := right.lower.Compare(sc, left.lower, collate.GetBinaryCollator()) + lowerCompare, err := right.lower.Compare(sc.TypeCtx(), left.lower, collate.GetBinaryCollator()) if err != nil { return nil, err } @@ -1370,7 +1370,7 @@ func mergePartitionBuckets(sc *stmtctx.StatementContext, buckets []*bucket4Mergi for i := len(buckets) - 1; i >= 0; i-- { totNDV += buckets[i].NDV res.Count += buckets[i].Count - compare, err := buckets[i].upper.Compare(sc, res.upper, collate.GetBinaryCollator()) + compare, err := buckets[i].upper.Compare(sc.TypeCtx(), res.upper, collate.GetBinaryCollator()) if err != nil { return nil, err } @@ -1426,7 +1426,7 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog continue } tmpValue := hist.GetLower(0) - res, err := tmpValue.Compare(sc, minValue, collate.GetBinaryCollator()) + res, err := tmpValue.Compare(sc.TypeCtx(), minValue, collate.GetBinaryCollator()) if err != nil { return nil, err } @@ -1455,7 +1455,7 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog minValue = d.Clone() continue } - res, err := d.Compare(sc, minValue, collate.GetBinaryCollator()) + res, err := d.Compare(sc.TypeCtx(), minValue, collate.GetBinaryCollator()) if err != nil { return nil, err } @@ -1480,14 +1480,14 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog var sortError error slices.SortFunc(buckets, func(i, j *bucket4Merging) int { - res, err := i.upper.Compare(sc, j.upper, collate.GetBinaryCollator()) + res, err := i.upper.Compare(sc.TypeCtx(), j.upper, collate.GetBinaryCollator()) if err != nil { sortError = err } if res != 0 { return res } - res, err = i.lower.Compare(sc, j.lower, collate.GetBinaryCollator()) + res, err = i.lower.Compare(sc.TypeCtx(), j.lower, collate.GetBinaryCollator()) if err != nil { sortError = err } @@ -1507,7 +1507,7 @@ func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histog bucketNDV += buckets[i].NDV if sum >= totCount*bucketCount/expBucketNumber && sum-prevSum >= gBucketCountThreshold { for ; i > 0; i-- { // if the buckets have the same upper, we merge them into the same new buckets. - res, err := buckets[i-1].upper.Compare(sc, buckets[i].upper, collate.GetBinaryCollator()) + res, err := buckets[i-1].upper.Compare(sc.TypeCtx(), buckets[i].upper, collate.GetBinaryCollator()) if err != nil { return nil, err } diff --git a/pkg/statistics/main_test.go b/pkg/statistics/main_test.go index 818e3bf1e46e3..0a460ffa6139c 100644 --- a/pkg/statistics/main_test.go +++ b/pkg/statistics/main_test.go @@ -128,7 +128,7 @@ func createTestStatisticsSamples(t *testing.T) *testStatisticsSamples { for i := start; i < rc.count; i += 5 { rc.data[i].SetInt64(rc.data[i].GetInt64() + 2) } - require.NoError(t, types.SortDatums(sc, rc.data)) + require.NoError(t, types.SortDatums(sc.TypeCtx(), rc.data)) s.rc = rc diff --git a/pkg/statistics/sample.go b/pkg/statistics/sample.go index c813f7e7886ff..54dc0a4892d13 100644 --- a/pkg/statistics/sample.go +++ b/pkg/statistics/sample.go @@ -68,7 +68,7 @@ func SortSampleItems(sc *stmtctx.StatementContext, items []*SampleItem) ([]*Samp var err error slices.SortStableFunc(sortedItems, func(i, j *SampleItem) int { var cmp int - cmp, err = i.Value.Compare(sc, &j.Value, collate.GetBinaryCollator()) + cmp, err = i.Value.Compare(sc.TypeCtx(), &j.Value, collate.GetBinaryCollator()) if err != nil { return -1 } diff --git a/pkg/statistics/statistics_test.go b/pkg/statistics/statistics_test.go index d126d5922e8dc..b4eb55efa9afc 100644 --- a/pkg/statistics/statistics_test.go +++ b/pkg/statistics/statistics_test.go @@ -180,11 +180,11 @@ func TestMergeHistogram(t *testing.T) { require.Equal(t, tt.bucketNum, h.Len()) require.Equal(t, tt.leftNum+tt.rightNum, int64(h.TotalRowCount())) expectLower := types.NewIntDatum(tt.leftLower) - cmp, err := h.GetLower(0).Compare(sc, &expectLower, collate.GetBinaryCollator()) + cmp, err := h.GetLower(0).Compare(sc.TypeCtx(), &expectLower, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) expectUpper := types.NewIntDatum(tt.rightLower + tt.rightNum - 1) - cmp, err = h.GetUpper(h.Len()-1).Compare(sc, &expectUpper, collate.GetBinaryCollator()) + cmp, err = h.GetUpper(h.Len()-1).Compare(sc.TypeCtx(), &expectUpper, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) } diff --git a/pkg/store/mockstore/mockcopr/aggregate.go b/pkg/store/mockstore/mockcopr/aggregate.go index a7f1dccf63421..afb9ea6e0972f 100644 --- a/pkg/store/mockstore/mockcopr/aggregate.go +++ b/pkg/store/mockstore/mockcopr/aggregate.go @@ -286,7 +286,7 @@ func (e *streamAggExec) meetNewGroup(row [][]byte) (bool, error) { return false, errors.Trace(err) } if matched { - c, err := d.Compare(e.evalCtx.sc, &e.nextGroupByRow[i], e.groupByCollators[i]) + c, err := d.Compare(e.evalCtx.sc.TypeCtx(), &e.nextGroupByRow[i], e.groupByCollators[i]) if err != nil { return false, errors.Trace(err) } diff --git a/pkg/store/mockstore/mockcopr/topn.go b/pkg/store/mockstore/mockcopr/topn.go index 083d90d2a3876..ac53b08173db6 100644 --- a/pkg/store/mockstore/mockcopr/topn.go +++ b/pkg/store/mockstore/mockcopr/topn.go @@ -50,7 +50,7 @@ func (t *topNSorter) Less(i, j int) bool { v1 := t.rows[i].key[index] v2 := t.rows[j].key[index] - ret, err := v1.Compare(t.sc, &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) + ret, err := v1.Compare(t.sc.TypeCtx(), &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) if err != nil { t.err = errors.Trace(err) return true @@ -99,7 +99,7 @@ func (t *topNHeap) Less(i, j int) bool { v1 := t.rows[i].key[index] v2 := t.rows[j].key[index] - ret, err := v1.Compare(t.sc, &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) + ret, err := v1.Compare(t.sc.TypeCtx(), &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) if err != nil { t.err = errors.Trace(err) return true diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go b/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go index 41eb69bb3e283..87aea1bd48dd6 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler_test.go @@ -352,10 +352,10 @@ func TestPointGet(t *testing.T) { // verify the returned rows value as input expectedRow := data.rows[handle] - eq, err := returnedRow[0].Compare(nil, &expectedRow[0], collate.GetBinaryCollator()) + eq, err := returnedRow[0].Compare(types.DefaultStmtNoWarningContext, &expectedRow[0], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, eq) - eq, err = returnedRow[1].Compare(nil, &expectedRow[1], collate.GetBinaryCollator()) + eq, err = returnedRow[1].Compare(types.DefaultStmtNoWarningContext, &expectedRow[1], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, eq) } diff --git a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go index c86b690298e7f..4ab16a8bb1dbd 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go @@ -850,7 +850,7 @@ type joinExec struct { } func (e *joinExec) getHashKey(keyCol types.Datum) (str string, err error) { - keyCol, err = keyCol.ConvertTo(e.sc, e.comKeyTp) + keyCol, err = keyCol.ConvertTo(e.sc.TypeCtx(), e.comKeyTp) if err != nil { return str, errors.Trace(err) } @@ -1076,7 +1076,7 @@ func (e *aggExec) processAllRows() (*chunk.Chunk, error) { result := agg.GetResult(aggCtxs[i]) if e.fieldTypes[i].GetType() == mysql.TypeLonglong && result.Kind() == types.KindMysqlDecimal { var err error - result, err = result.ConvertTo(e.sc, e.fieldTypes[i]) + result, err = result.ConvertTo(e.sc.TypeCtx(), e.fieldTypes[i]) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/topn.go b/pkg/store/mockstore/unistore/cophandler/topn.go index de2798a2902b9..80f0661710b23 100644 --- a/pkg/store/mockstore/unistore/cophandler/topn.go +++ b/pkg/store/mockstore/unistore/cophandler/topn.go @@ -53,7 +53,7 @@ func (t *topNSorter) Less(i, j int) bool { v1 := t.rows[i].key[index] v2 := t.rows[j].key[index] - ret, err := v1.Compare(t.sc, &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) + ret, err := v1.Compare(t.sc.TypeCtx(), &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) if err != nil { t.err = errors.Trace(err) return true @@ -107,7 +107,7 @@ func (t *topNHeap) Less(i, j int) bool { if expression.FieldTypeFromPB(by.GetExpr().GetFieldType()).GetType() == mysql.TypeEnum { ret = cmp.Compare(v1.GetUint64(), v2.GetUint64()) } else { - ret, err = v1.Compare(t.sc, &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) + ret, err = v1.Compare(t.sc.TypeCtx(), &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) if err != nil { t.err = errors.Trace(err) return true diff --git a/pkg/table/column.go b/pkg/table/column.go index 30ee46b96d7f7..1b3a6a5eaf7bd 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -295,7 +295,7 @@ func handleZeroDatetime(ctx sessionctx.Context, col *model.ColumnInfo, casted ty // TODO: change the third arg to TypeField. Not pass ColumnInfo. func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnErr, forceIgnoreTruncate bool) (casted types.Datum, err error) { sc := ctx.GetSessionVars().StmtCtx - casted, err = val.ConvertTo(sc, &col.FieldType) + casted, err = val.ConvertTo(sc.TypeCtx(), &col.FieldType) // TODO: make sure all truncate errors are handled by ConvertTo. if returnErr && err != nil { return casted, err diff --git a/pkg/table/column_test.go b/pkg/table/column_test.go index 48caa4257539d..c951fdf60e79f 100644 --- a/pkg/table/column_test.go +++ b/pkg/table/column_test.go @@ -132,7 +132,7 @@ func TestHandleBadNull(t *testing.T) { d := types.Datum{} err := col.HandleBadNull(&d, sc, 0) require.NoError(t, err) - cmp, err := d.Compare(sc, &types.Datum{}, collate.GetBinaryCollator()) + cmp, err := d.Compare(sc.TypeCtx(), &types.Datum{}, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -262,7 +262,7 @@ func TestGetZeroValue(t *testing.T) { colInfo := &model.ColumnInfo{FieldType: *tt.ft} zv := GetZeroValue(colInfo) require.Equal(t, tt.value.Kind(), zv.Kind()) - cmp, err := zv.Compare(sc, &tt.value, collate.GetCollator(tt.ft.GetCollate())) + cmp, err := zv.Compare(sc.TypeCtx(), &tt.value, collate.GetCollator(tt.ft.GetCollate())) require.NoError(t, err) require.Equal(t, 0, cmp) }) diff --git a/pkg/table/tables/mutation_checker.go b/pkg/table/tables/mutation_checker.go index 3ecb0b39080c8..3ef2c42730b1e 100644 --- a/pkg/table/tables/mutation_checker.go +++ b/pkg/table/tables/mutation_checker.go @@ -300,7 +300,7 @@ func checkRowInsertionConsistency( for columnID, decodedDatum := range decodedData { inputDatum := rowToInsert[columnIDToInfo[columnID].Offset] - cmp, err := decodedDatum.Compare(sessVars.StmtCtx, &inputDatum, collate.GetCollator(decodedDatum.Collation())) + cmp, err := decodedDatum.Compare(sessVars.StmtCtx.TypeCtx(), &inputDatum, collate.GetCollator(decodedDatum.Collation())) if err != nil { return errors.Trace(err) } @@ -398,7 +398,7 @@ func CompareIndexAndVal(sctx *stmtctx.StatementContext, rowVal types.Datum, idxV count := bj.GetElemCount() for elemIdx := 0; elemIdx < count; elemIdx++ { jsonDatum := types.NewJSONDatum(bj.ArrayGetElem(elemIdx)) - cmpRes, err = jsonDatum.Compare(sctx, &idxVal, collate.GetBinaryCollator()) + cmpRes, err = jsonDatum.Compare(sctx.TypeCtx(), &idxVal, collate.GetBinaryCollator()) if err != nil { return 0, errors.Trace(err) } @@ -407,7 +407,7 @@ func CompareIndexAndVal(sctx *stmtctx.StatementContext, rowVal types.Datum, idxV } } } else { - cmpRes, err = idxVal.Compare(sctx, &rowVal, collator) + cmpRes, err = idxVal.Compare(sctx.TypeCtx(), &rowVal, collator) } return cmpRes, err } diff --git a/pkg/table/tables/partition.go b/pkg/table/tables/partition.go index a7af6b55b5239..53178d07decdb 100644 --- a/pkg/table/tables/partition.go +++ b/pkg/table/tables/partition.go @@ -1103,7 +1103,7 @@ func (lp *ForListColumnPruning) genConstExprKey(ctx sessionctx.Context, sc *stmt } func (lp *ForListColumnPruning) genKey(sc *stmtctx.StatementContext, v types.Datum) ([]byte, error) { - v, err := v.ConvertTo(sc, lp.valueTp) + v, err := v.ConvertTo(sc.TypeCtx(), lp.valueTp) if err != nil { return nil, errors.Trace(err) } @@ -1456,7 +1456,7 @@ func (t *partitionedTable) locateHashPartition(ctx sessionctx.Context, partExpr data = r[col.Index] default: var err error - data, err = r[col.Index].ConvertTo(ctx.GetSessionVars().StmtCtx, types.NewFieldType(mysql.TypeLong)) + data, err = r[col.Index].ConvertTo(ctx.GetSessionVars().StmtCtx.TypeCtx(), types.NewFieldType(mysql.TypeLong)) if err != nil { return 0, err } diff --git a/pkg/tablecodec/tablecodec_test.go b/pkg/tablecodec/tablecodec_test.go index 79617385232a6..b3e133f74f159 100644 --- a/pkg/tablecodec/tablecodec_test.go +++ b/pkg/tablecodec/tablecodec_test.go @@ -117,7 +117,7 @@ func TestRowCodec(t *testing.T) { for i, col := range cols { v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) + equal, err1 := v.Compare(sc.TypeCtx(), &row[i], collate.GetBinaryCollator()) require.NoError(t, err1) require.Equalf(t, 0, equal, "expect: %v, got %v", row[i], v) } @@ -131,7 +131,7 @@ func TestRowCodec(t *testing.T) { for i, col := range cols { v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) + equal, err1 := v.Compare(sc.TypeCtx(), &row[i], collate.GetBinaryCollator()) require.NoError(t, err1) require.Equal(t, 0, equal) } @@ -149,7 +149,7 @@ func TestRowCodec(t *testing.T) { } v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) + equal, err1 := v.Compare(sc.TypeCtx(), &row[i], collate.GetBinaryCollator()) require.NoError(t, err1) require.Equal(t, 0, equal) } @@ -177,7 +177,7 @@ func TestDecodeColumnValue(t *testing.T) { tp := types.NewFieldType(mysql.TypeTimestamp) d1, err := DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) - cmp, err := d1.Compare(sc, &d, collate.GetBinaryCollator()) + cmp, err := d1.Compare(sc.TypeCtx(), &d, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -194,7 +194,7 @@ func TestDecodeColumnValue(t *testing.T) { tp.SetElems(elems) d1, err = DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) - cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.GetCollate())) + cmp, err = d1.Compare(sc.TypeCtx(), &d, collate.GetCollator(tp.GetCollate())) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -209,7 +209,7 @@ func TestDecodeColumnValue(t *testing.T) { tp.SetFlen(24) d1, err = DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) - cmp, err = d1.Compare(sc, &d, collate.GetBinaryCollator()) + cmp, err = d1.Compare(sc.TypeCtx(), &d, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -223,7 +223,7 @@ func TestDecodeColumnValue(t *testing.T) { tp = types.NewFieldType(mysql.TypeEnum) d1, err = DecodeColumnValue(bs, tp, sc.TimeZone()) require.NoError(t, err) - cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.GetCollate())) + cmp, err = d1.Compare(sc.TypeCtx(), &d, collate.GetCollator(tp.GetCollate())) require.NoError(t, err) require.Equal(t, 0, cmp) } @@ -234,7 +234,7 @@ func TestUnflattenDatums(t *testing.T) { tps := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)} output, err := UnflattenDatums(input, tps, sc.TimeZone()) require.NoError(t, err) - cmp, err := input[0].Compare(sc, &output[0], collate.GetBinaryCollator()) + cmp, err := input[0].Compare(sc.TypeCtx(), &output[0], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -243,7 +243,7 @@ func TestUnflattenDatums(t *testing.T) { tps[0].SetCollate("utf8mb4_unicode_ci") output, err = UnflattenDatums(input, tps, sc.TimeZone()) require.NoError(t, err) - cmp, err = input[0].Compare(sc, &output[0], collate.GetBinaryCollator()) + cmp, err = input[0].Compare(sc.TypeCtx(), &output[0], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) require.Equal(t, "utf8mb4_unicode_ci", output[0].Collation()) @@ -292,7 +292,7 @@ func TestTimeCodec(t *testing.T) { for i, col := range cols { v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) + equal, err1 := v.Compare(sc.TypeCtx(), &row[i], collate.GetBinaryCollator()) require.Nil(t, err1) require.Equal(t, 0, equal) } diff --git a/pkg/testkit/testutil/require.go b/pkg/testkit/testutil/require.go index 02acb7d604b55..876e9ffdf10b1 100644 --- a/pkg/testkit/testutil/require.go +++ b/pkg/testkit/testutil/require.go @@ -21,7 +21,6 @@ import ( "testing" "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/collate" "github.com/stretchr/testify/require" @@ -29,8 +28,7 @@ import ( // DatumEqual verifies that the actual value is equal to the expected value. For string datum, they are compared by the binary collation. func DatumEqual(t testing.TB, expected, actual types.Datum, msgAndArgs ...interface{}) { - sc := stmtctx.NewStmtCtx() - res, err := actual.Compare(sc, &expected, collate.GetBinaryCollator()) + res, err := actual.Compare(types.DefaultStmtNoWarningContext, &expected, collate.GetBinaryCollator()) require.NoError(t, err, msgAndArgs) require.Zero(t, res, msgAndArgs) } diff --git a/pkg/types/BUILD.bazel b/pkg/types/BUILD.bazel index f2be108b18850..416a8f81070f8 100644 --- a/pkg/types/BUILD.bazel +++ b/pkg/types/BUILD.bazel @@ -36,6 +36,7 @@ go_library( "overflow.go", "set.go", "time.go", + "truncate.go", ], importpath = "github.com/pingcap/tidb/pkg/types", visibility = [ @@ -50,11 +51,10 @@ go_library( "//pkg/parser/opcode", "//pkg/parser/terror", "//pkg/parser/types", - "//pkg/sessionctx/stmtctx", - "//pkg/types/context", "//pkg/util/collate", "//pkg/util/dbterror", "//pkg/util/hack", + "//pkg/util/intest", "//pkg/util/kvcache", "//pkg/util/logutil", "//pkg/util/mathutil", @@ -75,6 +75,7 @@ go_test( "binary_literal_test.go", "compare_test.go", "const_test.go", + "context_test.go", "convert_test.go", "core_time_test.go", "datum_test.go", @@ -103,7 +104,6 @@ go_test( "//pkg/parser/charset", "//pkg/parser/mysql", "//pkg/parser/terror", - "//pkg/sessionctx/stmtctx", "//pkg/testkit/testsetup", "//pkg/util/collate", "//pkg/util/hack", diff --git a/pkg/types/compare_test.go b/pkg/types/compare_test.go index b5d76497ec05e..1e08f58cfead8 100644 --- a/pkg/types/compare_test.go +++ b/pkg/types/compare_test.go @@ -20,7 +20,6 @@ import ( "time" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/util/collate" "github.com/stretchr/testify/require" ) @@ -146,11 +145,10 @@ func TestCompare(t *testing.T) { } func compareForTest(a, b interface{}) (int, error) { - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) aDatum := NewDatum(a) bDatum := NewDatum(b) - return aDatum.Compare(sc, &bDatum, collate.GetBinaryCollator()) + return aDatum.Compare(ctx, &bDatum, collate.GetBinaryCollator()) } func TestCompareDatum(t *testing.T) { @@ -168,14 +166,13 @@ func TestCompareDatum(t *testing.T) { {Datum{}, MinNotNullDatum(), -1}, {MinNotNullDatum(), MaxValueDatum(), -1}, } - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) for i, tt := range cmpTbl { - ret, err := tt.lhs.Compare(sc, &tt.rhs, collate.GetBinaryCollator()) + ret, err := tt.lhs.Compare(ctx, &tt.rhs, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, tt.ret, ret, "%d %v %v", i, tt.lhs, tt.rhs) - ret, err = tt.rhs.Compare(sc, &tt.lhs, collate.GetBinaryCollator()) + ret, err = tt.rhs.Compare(ctx, &tt.lhs, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, -tt.ret, ret, "%d %v %v", i, tt.lhs, tt.rhs) } diff --git a/pkg/types/context.go b/pkg/types/context.go index a0d6db5c3c77c..b969e4801f553 100644 --- a/pkg/types/context.go +++ b/pkg/types/context.go @@ -15,29 +15,238 @@ package types import ( - "github.com/pingcap/tidb/pkg/types/context" + "time" + + "github.com/pingcap/tidb/pkg/util/intest" ) -// TODO: move a contents in `types/context/context.go` to this file after refactor finished. -// Because package `types` has a dependency on `sessionctx/stmtctx`, we need a separate package `type/context` to define -// context objects during refactor works. +// StrictFlags is a flags with a fields unset and has the most strict behavior. +const StrictFlags Flags = 0 -// Context is an alias of `context.Context` -type Context = context.Context +// Flags indicate how to handle the conversion of a value. +type Flags uint16 -// Flags is an alias of `Flags` -type Flags = context.Flags +const ( + // FlagIgnoreTruncateErr indicates to ignore the truncate error. + // If this flag is set, `FlagTruncateAsWarning` will be ignored. + FlagIgnoreTruncateErr Flags = 1 << iota + // FlagTruncateAsWarning indicates to append the truncate error to warnings instead of returning it to user. + FlagTruncateAsWarning + // FlagAllowNegativeToUnsigned indicates to allow the casting from negative to unsigned int. + // When this flag is not set by default, casting a negative value to unsigned results an overflow error. + // Otherwise, a negative value will be cast to the corresponding unsigned value without any error. + // For example, when casting -1 to an unsigned bigint with `FlagAllowNegativeToUnsigned` set, + // we will get `18446744073709551615` which is the biggest unsigned value. + FlagAllowNegativeToUnsigned + // FlagIgnoreZeroDateErr indicates to ignore the zero-date error. + // See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_zero_date for details about the "zero-date" error. + // If this flag is set, `FlagZeroDateAsWarning` will be ignored. + // + // TODO: `FlagIgnoreZeroDateErr` and `FlagZeroDateAsWarning` don't represent the comments right now, because the + // errors related with `time` and `duration` are handled directly according to SQL mode in many places (expression, + // ddl ...). These error handling will be refined in the future. Currently, the `FlagZeroDateAsWarning` is not used, + // and the `FlagIgnoreZeroDateErr` is used to allow or disallow casting zero to date in `alter` statement. See #25728 + // This flag is the reverse of `NoZeroDate` in #30507. It's set to `true` for most context, and is only set to + // `false` for `alter` (and `create`) statements. + FlagIgnoreZeroDateErr + // FlagIgnoreZeroInDateErr indicates to ignore the zero-in-date error. + // See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_zero_in_date for details about the "zero-in-date" error. + FlagIgnoreZeroInDateErr + // FlagIgnoreInvalidDateErr indicates to ignore the invalid-date error. + // See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_allow_invalid_dates for details about the "invalid-date" error. + FlagIgnoreInvalidDateErr + // FlagSkipASCIICheck indicates to skip the ASCII check when converting the value to an ASCII string. + FlagSkipASCIICheck + // FlagSkipUTF8Check indicates to skip the UTF8 check when converting the value to an UTF8MB3 string. + FlagSkipUTF8Check + // FlagSkipUTF8MB4Check indicates to skip the UTF8MB4 check when converting the value to an UTF8 string. + FlagSkipUTF8MB4Check +) -// StrictFlags is a flags with a fields unset and has the most strict behavior. -const StrictFlags = context.StrictFlags +// AllowNegativeToUnsigned indicates whether the flag `FlagAllowNegativeToUnsigned` is set +func (f Flags) AllowNegativeToUnsigned() bool { + return f&FlagAllowNegativeToUnsigned != 0 +} + +// WithAllowNegativeToUnsigned returns a new flags with `FlagAllowNegativeToUnsigned` set/unset according to the clip parameter +func (f Flags) WithAllowNegativeToUnsigned(clip bool) Flags { + if clip { + return f | FlagAllowNegativeToUnsigned + } + return f &^ FlagAllowNegativeToUnsigned +} + +// SkipASCIICheck indicates whether the flag `FlagSkipASCIICheck` is set +func (f Flags) SkipASCIICheck() bool { + return f&FlagSkipASCIICheck != 0 +} + +// WithSkipSACIICheck returns a new flags with `FlagSkipASCIICheck` set/unset according to the skip parameter +func (f Flags) WithSkipSACIICheck(skip bool) Flags { + if skip { + return f | FlagSkipASCIICheck + } + return f &^ FlagSkipASCIICheck +} + +// SkipUTF8Check indicates whether the flag `FlagSkipUTF8Check` is set +func (f Flags) SkipUTF8Check() bool { + return f&FlagSkipUTF8Check != 0 +} + +// WithSkipUTF8Check returns a new flags with `FlagSkipUTF8Check` set/unset according to the skip parameter +func (f Flags) WithSkipUTF8Check(skip bool) Flags { + if skip { + return f | FlagSkipUTF8Check + } + return f &^ FlagSkipUTF8Check +} + +// SkipUTF8MB4Check indicates whether the flag `FlagSkipUTF8MB4Check` is set +func (f Flags) SkipUTF8MB4Check() bool { + return f&FlagSkipUTF8MB4Check != 0 +} + +// WithSkipUTF8MB4Check returns a new flags with `FlagSkipUTF8MB4Check` set/unset according to the skip parameter +func (f Flags) WithSkipUTF8MB4Check(skip bool) Flags { + if skip { + return f | FlagSkipUTF8MB4Check + } + return f &^ FlagSkipUTF8MB4Check +} + +// IgnoreTruncateErr indicates whether the flag `FlagIgnoreTruncateErr` is set +func (f Flags) IgnoreTruncateErr() bool { + return f&FlagIgnoreTruncateErr != 0 +} + +// WithIgnoreTruncateErr returns a new flags with `FlagIgnoreTruncateErr` set/unset according to the skip parameter +func (f Flags) WithIgnoreTruncateErr(ignore bool) Flags { + if ignore { + return f | FlagIgnoreTruncateErr + } + return f &^ FlagIgnoreTruncateErr +} + +// TruncateAsWarning indicates whether the flag `FlagTruncateAsWarning` is set +func (f Flags) TruncateAsWarning() bool { + return f&FlagTruncateAsWarning != 0 +} + +// WithTruncateAsWarning returns a new flags with `FlagTruncateAsWarning` set/unset according to the skip parameter +func (f Flags) WithTruncateAsWarning(warn bool) Flags { + if warn { + return f | FlagTruncateAsWarning + } + return f &^ FlagTruncateAsWarning +} + +// IgnoreZeroInDate indicates whether the flag `FlagIgnoreZeroInData` is set +func (f Flags) IgnoreZeroInDate() bool { + return f&FlagIgnoreZeroInDateErr != 0 +} + +// WithIgnoreZeroInDate returns a new flags with `FlagIgnoreZeroInDateErr` set/unset according to the ignore parameter +func (f Flags) WithIgnoreZeroInDate(ignore bool) Flags { + if ignore { + return f | FlagIgnoreZeroInDateErr + } + return f &^ FlagIgnoreZeroInDateErr +} + +// IgnoreInvalidDateErr indicates whether the flag `FlagIgnoreInvalidDateErr` is set +func (f Flags) IgnoreInvalidDateErr() bool { + return f&FlagIgnoreInvalidDateErr != 0 +} + +// WithIgnoreInvalidDateErr returns a new flags with `FlagIgnoreInvalidDateErr` set/unset according to the ignore parameter +func (f Flags) WithIgnoreInvalidDateErr(ignore bool) Flags { + if ignore { + return f | FlagIgnoreInvalidDateErr + } + return f &^ FlagIgnoreInvalidDateErr +} + +// IgnoreZeroDateErr indicates whether the flag `FlagIgnoreZeroDateErr` is set +func (f Flags) IgnoreZeroDateErr() bool { + return f&FlagIgnoreZeroDateErr != 0 +} + +// WithIgnoreZeroDateErr returns a new flags with `FlagIgnoreZeroDateErr` set/unset according to the ignore parameter +func (f Flags) WithIgnoreZeroDateErr(ignore bool) Flags { + if ignore { + return f | FlagIgnoreZeroDateErr + } + return f &^ FlagIgnoreZeroDateErr +} + +// Context provides the information when converting between different types. +type Context struct { + flags Flags + loc *time.Location + appendWarningFn func(err error) +} // NewContext creates a new `Context` -var NewContext = context.NewContext +func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context { + intest.Assert(loc != nil && appendWarningFn != nil) + return Context{ + flags: flags, + loc: loc, + appendWarningFn: appendWarningFn, + } +} + +// Flags returns the flags of the context +func (c *Context) Flags() Flags { + return c.flags +} + +// WithFlags returns a new context with the flags set to the given value +func (c *Context) WithFlags(f Flags) Context { + ctx := *c + ctx.flags = f + return ctx +} + +// WithLocation returns a new context with the given location +func (c *Context) WithLocation(loc *time.Location) Context { + intest.Assert(loc) + ctx := *c + ctx.loc = loc + return ctx +} + +// Location returns the location of the context +func (c *Context) Location() *time.Location { + intest.Assert(c.loc) + if c.loc == nil { + // c.loc should always not be nil, just make the code safe here. + return time.UTC + } + return c.loc +} + +// AppendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing. +func (c *Context) AppendWarning(err error) { + intest.Assert(c.appendWarningFn != nil) + if fn := c.appendWarningFn; fn != nil { + // appendWarningFn should always not be nil, check fn != nil here to just make code safe. + fn(err) + } +} + +// AppendWarningFunc returns the inner `appendWarningFn` +func (c *Context) AppendWarningFunc() func(err error) { + return c.appendWarningFn +} // DefaultStmtFlags is the default flags for statement context with the flag `FlagAllowNegativeToUnsigned` set. // TODO: make DefaultStmtFlags to be equal with StrictFlags, and setting flag `FlagAllowNegativeToUnsigned` // is only for make the code to be equivalent with the old implement during refactoring. -const DefaultStmtFlags = context.DefaultStmtFlags +const DefaultStmtFlags = StrictFlags | FlagAllowNegativeToUnsigned | FlagIgnoreZeroDateErr -// DefaultStmtNoWarningContext is an alias of `DefaultStmtNoWarningContext` -var DefaultStmtNoWarningContext = context.DefaultStmtNoWarningContext +// DefaultStmtNoWarningContext is the context with default statement flags without any other special configuration +var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, func(_ error) { + // the error is ignored +}) diff --git a/pkg/types/context/BUILD.bazel b/pkg/types/context/BUILD.bazel deleted file mode 100644 index 5d9a1d17e6e49..0000000000000 --- a/pkg/types/context/BUILD.bazel +++ /dev/null @@ -1,25 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") - -go_library( - name = "context", - srcs = [ - "context.go", - "truncate.go", - ], - importpath = "github.com/pingcap/tidb/pkg/types/context", - visibility = ["//visibility:public"], - deps = [ - "//pkg/errno", - "//pkg/util/intest", - "@com_github_pingcap_errors//:errors", - ], -) - -go_test( - name = "context_test", - timeout = "short", - srcs = ["context_test.go"], - embed = [":context"], - flaky = True, - deps = ["@com_github_stretchr_testify//require"], -) diff --git a/pkg/types/context/context.go b/pkg/types/context/context.go deleted file mode 100644 index d315cb5019c27..0000000000000 --- a/pkg/types/context/context.go +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "time" - - "github.com/pingcap/tidb/pkg/util/intest" -) - -// StrictFlags is a flags with a fields unset and has the most strict behavior. -const StrictFlags Flags = 0 - -// Flags indicate how to handle the conversion of a value. -type Flags uint16 - -const ( - // FlagIgnoreTruncateErr indicates to ignore the truncate error. - // If this flag is set, `FlagTruncateAsWarning` will be ignored. - FlagIgnoreTruncateErr Flags = 1 << iota - // FlagTruncateAsWarning indicates to append the truncate error to warnings instead of returning it to user. - FlagTruncateAsWarning - // FlagAllowNegativeToUnsigned indicates to allow the casting from negative to unsigned int. - // When this flag is not set by default, casting a negative value to unsigned results an overflow error. - // Otherwise, a negative value will be cast to the corresponding unsigned value without any error. - // For example, when casting -1 to an unsigned bigint with `FlagAllowNegativeToUnsigned` set, - // we will get `18446744073709551615` which is the biggest unsigned value. - FlagAllowNegativeToUnsigned - // FlagIgnoreZeroDateErr indicates to ignore the zero-date error. - // See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_zero_date for details about the "zero-date" error. - // If this flag is set, `FlagZeroDateAsWarning` will be ignored. - // - // TODO: `FlagIgnoreZeroDateErr` and `FlagZeroDateAsWarning` don't represent the comments right now, because the - // errors related with `time` and `duration` are handled directly according to SQL mode in many places (expression, - // ddl ...). These error handling will be refined in the future. Currently, the `FlagZeroDateAsWarning` is not used, - // and the `FlagIgnoreZeroDateErr` is used to allow or disallow casting zero to date in `alter` statement. See #25728 - // This flag is the reverse of `NoZeroDate` in #30507. It's set to `true` for most context, and is only set to - // `false` for `alter` (and `create`) statements. - FlagIgnoreZeroDateErr - // FlagIgnoreZeroInDateErr indicates to ignore the zero-in-date error. - // See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_no_zero_in_date for details about the "zero-in-date" error. - FlagIgnoreZeroInDateErr - // FlagIgnoreInvalidDateErr indicates to ignore the invalid-date error. - // See: https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_allow_invalid_dates for details about the "invalid-date" error. - FlagIgnoreInvalidDateErr - // FlagSkipASCIICheck indicates to skip the ASCII check when converting the value to an ASCII string. - FlagSkipASCIICheck - // FlagSkipUTF8Check indicates to skip the UTF8 check when converting the value to an UTF8MB3 string. - FlagSkipUTF8Check - // FlagSkipUTF8MB4Check indicates to skip the UTF8MB4 check when converting the value to an UTF8 string. - FlagSkipUTF8MB4Check -) - -// AllowNegativeToUnsigned indicates whether the flag `FlagAllowNegativeToUnsigned` is set -func (f Flags) AllowNegativeToUnsigned() bool { - return f&FlagAllowNegativeToUnsigned != 0 -} - -// WithAllowNegativeToUnsigned returns a new flags with `FlagAllowNegativeToUnsigned` set/unset according to the clip parameter -func (f Flags) WithAllowNegativeToUnsigned(clip bool) Flags { - if clip { - return f | FlagAllowNegativeToUnsigned - } - return f &^ FlagAllowNegativeToUnsigned -} - -// SkipASCIICheck indicates whether the flag `FlagSkipASCIICheck` is set -func (f Flags) SkipASCIICheck() bool { - return f&FlagSkipASCIICheck != 0 -} - -// WithSkipSACIICheck returns a new flags with `FlagSkipASCIICheck` set/unset according to the skip parameter -func (f Flags) WithSkipSACIICheck(skip bool) Flags { - if skip { - return f | FlagSkipASCIICheck - } - return f &^ FlagSkipASCIICheck -} - -// SkipUTF8Check indicates whether the flag `FlagSkipUTF8Check` is set -func (f Flags) SkipUTF8Check() bool { - return f&FlagSkipUTF8Check != 0 -} - -// WithSkipUTF8Check returns a new flags with `FlagSkipUTF8Check` set/unset according to the skip parameter -func (f Flags) WithSkipUTF8Check(skip bool) Flags { - if skip { - return f | FlagSkipUTF8Check - } - return f &^ FlagSkipUTF8Check -} - -// SkipUTF8MB4Check indicates whether the flag `FlagSkipUTF8MB4Check` is set -func (f Flags) SkipUTF8MB4Check() bool { - return f&FlagSkipUTF8MB4Check != 0 -} - -// WithSkipUTF8MB4Check returns a new flags with `FlagSkipUTF8MB4Check` set/unset according to the skip parameter -func (f Flags) WithSkipUTF8MB4Check(skip bool) Flags { - if skip { - return f | FlagSkipUTF8MB4Check - } - return f &^ FlagSkipUTF8MB4Check -} - -// IgnoreTruncateErr indicates whether the flag `FlagIgnoreTruncateErr` is set -func (f Flags) IgnoreTruncateErr() bool { - return f&FlagIgnoreTruncateErr != 0 -} - -// WithIgnoreTruncateErr returns a new flags with `FlagIgnoreTruncateErr` set/unset according to the skip parameter -func (f Flags) WithIgnoreTruncateErr(ignore bool) Flags { - if ignore { - return f | FlagIgnoreTruncateErr - } - return f &^ FlagIgnoreTruncateErr -} - -// TruncateAsWarning indicates whether the flag `FlagTruncateAsWarning` is set -func (f Flags) TruncateAsWarning() bool { - return f&FlagTruncateAsWarning != 0 -} - -// WithTruncateAsWarning returns a new flags with `FlagTruncateAsWarning` set/unset according to the skip parameter -func (f Flags) WithTruncateAsWarning(warn bool) Flags { - if warn { - return f | FlagTruncateAsWarning - } - return f &^ FlagTruncateAsWarning -} - -// IgnoreZeroInDate indicates whether the flag `FlagIgnoreZeroInData` is set -func (f Flags) IgnoreZeroInDate() bool { - return f&FlagIgnoreZeroInDateErr != 0 -} - -// WithIgnoreZeroInDate returns a new flags with `FlagIgnoreZeroInDateErr` set/unset according to the ignore parameter -func (f Flags) WithIgnoreZeroInDate(ignore bool) Flags { - if ignore { - return f | FlagIgnoreZeroInDateErr - } - return f &^ FlagIgnoreZeroInDateErr -} - -// IgnoreInvalidDateErr indicates whether the flag `FlagIgnoreInvalidDateErr` is set -func (f Flags) IgnoreInvalidDateErr() bool { - return f&FlagIgnoreInvalidDateErr != 0 -} - -// WithIgnoreInvalidDateErr returns a new flags with `FlagIgnoreInvalidDateErr` set/unset according to the ignore parameter -func (f Flags) WithIgnoreInvalidDateErr(ignore bool) Flags { - if ignore { - return f | FlagIgnoreInvalidDateErr - } - return f &^ FlagIgnoreInvalidDateErr -} - -// IgnoreZeroDateErr indicates whether the flag `FlagIgnoreZeroDateErr` is set -func (f Flags) IgnoreZeroDateErr() bool { - return f&FlagIgnoreZeroDateErr != 0 -} - -// WithIgnoreZeroDateErr returns a new flags with `FlagIgnoreZeroDateErr` set/unset according to the ignore parameter -func (f Flags) WithIgnoreZeroDateErr(ignore bool) Flags { - if ignore { - return f | FlagIgnoreZeroDateErr - } - return f &^ FlagIgnoreZeroDateErr -} - -// Context provides the information when converting between different types. -type Context struct { - flags Flags - loc *time.Location - appendWarningFn func(err error) -} - -// NewContext creates a new `Context` -func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context { - intest.Assert(loc != nil && appendWarningFn != nil) - return Context{ - flags: flags, - loc: loc, - appendWarningFn: appendWarningFn, - } -} - -// Flags returns the flags of the context -func (c *Context) Flags() Flags { - return c.flags -} - -// WithFlags returns a new context with the flags set to the given value -func (c *Context) WithFlags(f Flags) Context { - ctx := *c - ctx.flags = f - return ctx -} - -// WithLocation returns a new context with the given location -func (c *Context) WithLocation(loc *time.Location) Context { - intest.Assert(loc) - ctx := *c - ctx.loc = loc - return ctx -} - -// Location returns the location of the context -func (c *Context) Location() *time.Location { - intest.Assert(c.loc) - if c.loc == nil { - // c.loc should always not be nil, just make the code safe here. - return time.UTC - } - return c.loc -} - -// AppendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing. -func (c *Context) AppendWarning(err error) { - intest.Assert(c.appendWarningFn != nil) - if fn := c.appendWarningFn; fn != nil { - // appendWarningFn should always not be nil, check fn != nil here to just make code safe. - fn(err) - } -} - -// AppendWarningFunc returns the inner `appendWarningFn` -func (c *Context) AppendWarningFunc() func(err error) { - return c.appendWarningFn -} - -// DefaultStmtFlags is the default flags for statement context with the flag `FlagAllowNegativeToUnsigned` set. -// TODO: make DefaultStmtFlags to be equal with StrictFlags, and setting flag `FlagAllowNegativeToUnsigned` -// is only for make the code to be equivalent with the old implement during refactoring. -const DefaultStmtFlags = StrictFlags | FlagAllowNegativeToUnsigned | FlagIgnoreZeroDateErr - -// DefaultStmtNoWarningContext is the context with default statement flags without any other special configuration -var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, func(_ error) { - // the error is ignored -}) diff --git a/pkg/types/context/context_test.go b/pkg/types/context_test.go similarity index 88% rename from pkg/types/context/context_test.go rename to pkg/types/context_test.go index 4b51cc9bb2ef3..0bd8d5cbcc811 100644 --- a/pkg/types/context/context_test.go +++ b/pkg/types/context_test.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package context +package types import ( "fmt" + "sync" "testing" "time" @@ -105,3 +106,26 @@ func TestSimpleOnOffFlags(t *testing.T) { require.False(t, c.readFn(f), msg) } } + +type warnStore struct { + sync.Mutex + warnings []error +} + +func (w *warnStore) AppendWarning(warn error) { + w.Lock() + defer w.Unlock() + + w.warnings = append(w.warnings, warn) +} + +func (w *warnStore) Reset() { + w.Lock() + defer w.Unlock() + + w.warnings = nil +} + +func (w *warnStore) GetWarnings() []error { + return w.warnings +} diff --git a/pkg/types/convert.go b/pkg/types/convert.go index 02435c0bed3b1..7a860843d6afd 100644 --- a/pkg/types/convert.go +++ b/pkg/types/convert.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/util/hack" ) @@ -230,7 +229,7 @@ func convertScientificNotation(str string) (string, error) { } } -func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBound uint64, tp byte) (uint64, error) { +func convertDecimalStrToUint(str string, upperBound uint64, tp byte) (uint64, error) { str, err := convertScientificNotation(str) if err != nil { return 0, err @@ -271,8 +270,8 @@ func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBoun } // ConvertDecimalToUint converts a decimal to a uint by converting it to a string first to avoid float overflow (#10181). -func ConvertDecimalToUint(sc *stmtctx.StatementContext, d *MyDecimal, upperBound uint64, tp byte) (uint64, error) { - return convertDecimalStrToUint(sc, string(d.ToString()), upperBound, tp) +func ConvertDecimalToUint(d *MyDecimal, upperBound uint64, tp byte) (uint64, error) { + return convertDecimalStrToUint(string(d.ToString()), upperBound, tp) } // StrToInt converts a string to an integer at the best-effort. diff --git a/pkg/types/convert_test.go b/pkg/types/convert_test.go index 8578671d7de18..fed44d5670603 100644 --- a/pkg/types/convert_test.go +++ b/pkg/types/convert_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/stretchr/testify/require" ) @@ -35,8 +34,7 @@ type invalidMockType struct { // Convert converts the val with type tp. func Convert(val interface{}, target *FieldType) (v interface{}, err error) { d := NewDatum(val) - sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) - ret, err := d.ConvertTo(sc, target) + ret, err := d.ConvertTo(DefaultStmtNoWarningContext, target) if err != nil { return ret.GetValue(), errors.Trace(err) } @@ -383,8 +381,7 @@ func TestConvertToString(t *testing.T) { ft.SetFlen(tt.flen) ft.SetCharset(tt.charset) inputDatum := NewStringDatum(tt.input) - sc := stmtctx.NewStmtCtx() - outputDatum, err := inputDatum.ConvertTo(sc, ft) + outputDatum, err := inputDatum.ConvertTo(DefaultStmtNoWarningContext, ft) if tt.input != tt.output { require.True(t, ErrDataTooLong.Equal(err), "flen: %d, charset: %s, input: %s, output: %s", tt.flen, tt.charset, tt.input, tt.output) } else { @@ -420,10 +417,9 @@ func TestConvertToStringWithCheck(t *testing.T) { ft.SetFlen(255) ft.SetCharset(tt.outputChs) inputDatum := NewStringDatum(tt.input) - sc := stmtctx.NewStmtCtx() - flags := tt.newFlags(sc.TypeFlags()) - sc.SetTypeFlags(flags) - outputDatum, err := inputDatum.ConvertTo(sc, ft) + ctx := DefaultStmtNoWarningContext + ctx = ctx.WithFlags(tt.newFlags(DefaultStmtFlags)) + outputDatum, err := inputDatum.ConvertTo(ctx, ft) if len(tt.output) == 0 { require.True(t, charset.ErrInvalidCharacterString.Equal(err), tt) } else { @@ -460,8 +456,7 @@ func TestConvertToBinaryString(t *testing.T) { ft.SetFlen(255) ft.SetCharset(tt.outputCharset) inputDatum := NewCollationStringDatum(tt.input, tt.inputCollate) - sc := stmtctx.NewStmtCtx() - outputDatum, err := inputDatum.ConvertTo(sc, ft) + outputDatum, err := inputDatum.ConvertTo(DefaultStmtNoWarningContext, ft) if len(tt.output) == 0 { require.True(t, charset.ErrInvalidCharacterString.Equal(err), tt) } else { @@ -555,34 +550,22 @@ func TestStrToNum(t *testing.T) { } func testSelectUpdateDeleteEmptyStringError(t *testing.T) { - testCases := []struct { - inSelect bool - inDelete bool - }{ - {true, false}, - {false, true}, - } - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) - for _, tc := range testCases { - sc.InSelectStmt = tc.inSelect - sc.InDeleteStmt = tc.inDelete + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithTruncateAsWarning(true)) - str := "" - expect := 0 + str := "" + expect := 0 - val, err := StrToInt(sc.TypeCtxOrDefault(), str, false) - require.NoError(t, err) - require.Equal(t, int64(expect), val) + val, err := StrToInt(ctx, str, false) + require.NoError(t, err) + require.Equal(t, int64(expect), val) - val1, err := StrToUint(sc.TypeCtxOrDefault(), str, false) - require.NoError(t, err) - require.Equal(t, uint64(expect), val1) + val1, err := StrToUint(ctx, str, false) + require.NoError(t, err) + require.Equal(t, uint64(expect), val1) - val2, err := StrToFloat(sc.TypeCtxOrDefault(), str, false) - require.NoError(t, err) - require.Equal(t, float64(expect), val2) - } + val2, err := StrToFloat(ctx, str, false) + require.NoError(t, err) + require.Equal(t, float64(expect), val2) } func TestFieldTypeToStr(t *testing.T) { @@ -600,10 +583,8 @@ func accept(t *testing.T, tp byte, value interface{}, unsigned bool, expected st ft.AddFlag(mysql.UnsignedFlag) } d := NewDatum(value) - sc := stmtctx.NewStmtCtx() - sc.SetTimeZone(time.UTC) - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) - casted, err := d.ConvertTo(sc, ft) + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) + casted, err := d.ConvertTo(ctx, ft) require.NoErrorf(t, err, "%v", ft) if casted.IsNull() { require.Equal(t, "", expected) @@ -628,8 +609,7 @@ func deny(t *testing.T, tp byte, value interface{}, unsigned bool, expected stri ft.AddFlag(mysql.UnsignedFlag) } d := NewDatum(value) - sc := stmtctx.NewStmtCtx() - casted, err := d.ConvertTo(sc, ft) + casted, err := d.ConvertTo(DefaultStmtNoWarningContext, ft) require.Error(t, err) if casted.IsNull() { require.Equal(t, "", expected) @@ -883,12 +863,11 @@ func TestGetValidInt(t *testing.T) { {"123e+", "123", true, true}, {"123de", "123", true, true}, } - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) - sc.InSelectStmt = true + warnings := &warnStore{} + ctx := NewContext(DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, warnings.AppendWarning) warningCount := 0 for i, tt := range tests { - prefix, err := getValidIntPrefix(sc.TypeCtxOrDefault(), tt.origin, false) + prefix, err := getValidIntPrefix(ctx, tt.origin, false) require.NoError(t, err) require.Equal(t, tt.valid, prefix) if tt.signed { @@ -897,13 +876,13 @@ func TestGetValidInt(t *testing.T) { _, err = strconv.ParseUint(prefix, 10, 64) } require.NoError(t, err) - warnings := sc.GetWarnings() + warn := warnings.GetWarnings() if tt.warning { - require.Lenf(t, warnings, warningCount+1, "%d", i) - require.True(t, terror.ErrorEqual(warnings[len(warnings)-1].Err, ErrTruncatedWrongVal)) + require.Lenf(t, warn, warningCount+1, "%d", i) + require.True(t, terror.ErrorEqual(warn[len(warn)-1], ErrTruncatedWrongVal)) warningCount++ } else { - require.Len(t, warnings, warningCount) + require.Len(t, warn, warningCount) } } @@ -927,10 +906,9 @@ func TestGetValidInt(t *testing.T) { {"123e+", "123", true}, {"123de", "123", true}, } - sc.SetTypeFlags(DefaultStmtFlags) - sc.InSelectStmt = false + ctx = ctx.WithFlags(DefaultStmtFlags) for _, tt := range tests2 { - prefix, err := getValidIntPrefix(sc.TypeCtxOrDefault(), tt.origin, false) + prefix, err := getValidIntPrefix(ctx, tt.origin, false) if tt.warning { require.True(t, terror.ErrorEqual(err, ErrTruncatedWrongVal)) } else { @@ -1017,12 +995,12 @@ func TestConvertTime(t *testing.T) { } for _, timezone := range timezones { - sc := stmtctx.NewStmtCtxWithTimeZone(timezone) - testConvertTimeTimeZone(t, sc) + ctx := DefaultStmtNoWarningContext.WithLocation(timezone) + testConvertTimeTimeZone(t, ctx) } } -func testConvertTimeTimeZone(t *testing.T, sc *stmtctx.StatementContext) { +func testConvertTimeTimeZone(t *testing.T, ctx Context) { raw := FromDate(2002, 3, 4, 4, 6, 7, 8) tests := []struct { input Time @@ -1054,7 +1032,7 @@ func testConvertTimeTimeZone(t *testing.T, sc *stmtctx.StatementContext) { for _, test := range tests { var d Datum d.SetMysqlTime(test.input) - nd, err := d.ConvertTo(sc, test.target) + nd, err := d.ConvertTo(ctx, test.target) require.NoError(t, err) v := nd.GetMysqlTime() require.Equal(t, test.expect.Type(), v.Type()) @@ -1084,7 +1062,7 @@ func TestConvertJSONToInt(t *testing.T) { j, err := ParseBinaryJSONFromString(tt.in) require.NoError(t, err) - casted, err := ConvertJSONToInt64(stmtctx.NewStmtCtx().TypeCtx(), j, false) + casted, err := ConvertJSONToInt64(DefaultStmtNoWarningContext, j, false) if tt.err { require.Error(t, err, tt) } else { @@ -1287,7 +1265,7 @@ func TestConvertDecimalStrToUint(t *testing.T) { {"-10000000000000000000.0", 0, false}, } for _, ca := range cases { - result, err := convertDecimalStrToUint(stmtctx.NewStmtCtx(), ca.input, math.MaxUint64, 0) + result, err := convertDecimalStrToUint(ca.input, math.MaxUint64, 0) if !ca.succ { require.Error(t, err) } else { @@ -1296,11 +1274,11 @@ func TestConvertDecimalStrToUint(t *testing.T) { require.Equal(t, ca.result, result, "input=%v", ca.input) } - result, err := convertDecimalStrToUint(stmtctx.NewStmtCtx(), "-99.0", math.MaxUint8, 0) + result, err := convertDecimalStrToUint("-99.0", math.MaxUint8, 0) require.Error(t, err) require.Equal(t, uint64(0), result) - result, err = convertDecimalStrToUint(stmtctx.NewStmtCtx(), "-100.0", math.MaxUint8, 0) + result, err = convertDecimalStrToUint("-100.0", math.MaxUint8, 0) require.Error(t, err) require.Equal(t, uint64(0), result) } diff --git a/pkg/types/datum.go b/pkg/types/datum.go index 2408669b9604e..5d5a804eba54e 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -32,7 +32,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" "github.com/pingcap/tidb/pkg/parser/types" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/hack" "github.com/pingcap/tidb/pkg/util/logutil" @@ -630,13 +629,9 @@ func (d *Datum) SetValue(val interface{}, tp *types.FieldType) { // Compare compares datum to another datum. // Notes: don't rely on datum.collation to get the collator, it's tend to buggy. -func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collate.Collator) (int, error) { - typeCtx := DefaultStmtNoWarningContext - if sc != nil { - typeCtx = sc.TypeCtx() - } +func (d *Datum) Compare(ctx Context, ad *Datum, comparer collate.Collator) (int, error) { if d.k == KindMysqlJSON && ad.k != KindMysqlJSON { - cmp, err := ad.Compare(sc, d, comparer) + cmp, err := ad.Compare(ctx, d, comparer) return cmp * -1, errors.Trace(err) } switch ad.k { @@ -658,29 +653,29 @@ func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collat } return -1, nil case KindInt64: - return d.compareInt64(typeCtx, ad.GetInt64()) + return d.compareInt64(ctx, ad.GetInt64()) case KindUint64: - return d.compareUint64(typeCtx, ad.GetUint64()) + return d.compareUint64(ctx, ad.GetUint64()) case KindFloat32, KindFloat64: - return d.compareFloat64(typeCtx, ad.GetFloat64()) + return d.compareFloat64(ctx, ad.GetFloat64()) case KindString: - return d.compareString(typeCtx, ad.GetString(), comparer) + return d.compareString(ctx, ad.GetString(), comparer) case KindBytes: - return d.compareString(typeCtx, ad.GetString(), comparer) + return d.compareString(ctx, ad.GetString(), comparer) case KindMysqlDecimal: - return d.compareMysqlDecimal(sc, ad.GetMysqlDecimal()) + return d.compareMysqlDecimal(ctx, ad.GetMysqlDecimal()) case KindMysqlDuration: - return d.compareMysqlDuration(typeCtx, ad.GetMysqlDuration()) + return d.compareMysqlDuration(ctx, ad.GetMysqlDuration()) case KindMysqlEnum: - return d.compareMysqlEnum(typeCtx, ad.GetMysqlEnum(), comparer) + return d.compareMysqlEnum(ctx, ad.GetMysqlEnum(), comparer) case KindBinaryLiteral, KindMysqlBit: - return d.compareBinaryLiteral(typeCtx, ad.GetBinaryLiteral4Cmp(), comparer) + return d.compareBinaryLiteral(ctx, ad.GetBinaryLiteral4Cmp(), comparer) case KindMysqlSet: - return d.compareMysqlSet(typeCtx, ad.GetMysqlSet(), comparer) + return d.compareMysqlSet(ctx, ad.GetMysqlSet(), comparer) case KindMysqlJSON: - return d.compareMysqlJSON(sc, ad.GetMysqlJSON()) + return d.compareMysqlJSON(ad.GetMysqlJSON()) case KindMysqlTime: - return d.compareMysqlTime(typeCtx, ad.GetMysqlTime()) + return d.compareMysqlTime(ctx, ad.GetMysqlTime()) default: return 0, nil } @@ -790,9 +785,7 @@ func (d *Datum) compareString(ctx Context, s string, comparer collate.Collator) } } -func (d *Datum) compareMysqlDecimal(sc *stmtctx.StatementContext, dec *MyDecimal) (int, error) { - typeCtx := sc.TypeCtxOrDefault() - +func (d *Datum) compareMysqlDecimal(ctx Context, dec *MyDecimal) (int, error) { switch d.k { case KindNull, KindMinNotNull: return -1, nil @@ -802,10 +795,10 @@ func (d *Datum) compareMysqlDecimal(sc *stmtctx.StatementContext, dec *MyDecimal return d.GetMysqlDecimal().Compare(dec), nil case KindString, KindBytes: dDec := new(MyDecimal) - err := typeCtx.HandleTruncate(dDec.FromString(d.GetBytes())) + err := ctx.HandleTruncate(dDec.FromString(d.GetBytes())) return dDec.Compare(dec), errors.Trace(err) default: - dVal, err := d.ConvertTo(sc, NewFieldType(mysql.TypeNewDecimal)) + dVal, err := d.ConvertTo(ctx, NewFieldType(mysql.TypeNewDecimal)) if err != nil { return 0, errors.Trace(err) } @@ -875,7 +868,7 @@ func (d *Datum) compareMysqlSet(ctx Context, set Set, comparer collate.Collator) } } -func (d *Datum) compareMysqlJSON(_ *stmtctx.StatementContext, target BinaryJSON) (int, error) { +func (d *Datum) compareMysqlJSON(target BinaryJSON) (int, error) { // json is not equal with NULL if d.k == KindNull { return 1, nil @@ -910,9 +903,7 @@ func (d *Datum) compareMysqlTime(ctx Context, time Time) (int, error) { // ConvertTo converts a datum to the target field type. // change this method need sync modification to type2Kind in rowcodec/types.go -func (d *Datum) ConvertTo(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { - typeCtx := sc.TypeCtxOrDefault() - +func (d *Datum) ConvertTo(ctx Context, target *FieldType) (Datum, error) { if d.k == KindNull { return Datum{}, nil } @@ -920,32 +911,32 @@ func (d *Datum) ConvertTo(sc *stmtctx.StatementContext, target *FieldType) (Datu case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: unsigned := mysql.HasUnsignedFlag(target.GetFlag()) if unsigned { - return d.convertToUint(sc, target) + return d.convertToUint(ctx, target) } - return d.convertToInt(sc, target) + return d.convertToInt(ctx, target) case mysql.TypeFloat, mysql.TypeDouble: - return d.convertToFloat(sc.TypeCtxOrDefault(), target) + return d.convertToFloat(ctx, target) case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString: - return d.convertToString(sc, target) + return d.convertToString(ctx, target) case mysql.TypeTimestamp: - return d.convertToMysqlTimestamp(typeCtx, target) + return d.convertToMysqlTimestamp(ctx, target) case mysql.TypeDatetime, mysql.TypeDate: - return d.convertToMysqlTime(typeCtx, target) + return d.convertToMysqlTime(ctx, target) case mysql.TypeDuration: - return d.convertToMysqlDuration(sc, target) + return d.convertToMysqlDuration(ctx, target) case mysql.TypeNewDecimal: - return d.convertToMysqlDecimal(sc.TypeCtxOrDefault(), target) + return d.convertToMysqlDecimal(ctx, target) case mysql.TypeYear: - return d.ConvertToMysqlYear(sc, target) + return d.ConvertToMysqlYear(ctx, target) case mysql.TypeEnum: - return d.convertToMysqlEnum(sc, target) + return d.convertToMysqlEnum(ctx, target) case mysql.TypeBit: - return d.convertToMysqlBit(sc, target) + return d.convertToMysqlBit(ctx, target) case mysql.TypeSet: - return d.convertToMysqlSet(sc, target) + return d.convertToMysqlSet(ctx, target) case mysql.TypeJSON: - return d.convertToMysqlJSON(sc, target) + return d.convertToMysqlJSON(target) case mysql.TypeNull: return Datum{}, nil default: @@ -1032,13 +1023,12 @@ func ProduceFloatWithSpecifiedTp(f float64, target *FieldType) (_ float64, err e return f, errors.Trace(err) } -func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { +func (d *Datum) convertToString(ctx Context, target *FieldType) (Datum, error) { var ( ret Datum s string err error ) - ctx := sc.TypeCtx() switch d.k { case KindInt64: s = strconv.FormatInt(d.GetInt64(), 10) @@ -1183,12 +1173,12 @@ func ProduceStrWithSpecifiedTp(s string, tp *FieldType, ctx Context, padZero boo return s, errors.Trace(ctx.HandleTruncate(err)) } -func (d *Datum) convertToInt(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { - i64, err := d.toSignedInteger(sc, target.GetType()) +func (d *Datum) convertToInt(ctx Context, target *FieldType) (Datum, error) { + i64, err := d.toSignedInteger(ctx, target.GetType()) return NewIntDatum(i64), errors.Trace(err) } -func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { +func (d *Datum) convertToUint(ctx Context, target *FieldType) (Datum, error) { tp := target.GetType() upperBound := IntergerUnsignedUpperBound(tp) var ( @@ -1198,14 +1188,14 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( ) switch d.k { case KindInt64: - val, err = ConvertIntToUint(sc.TypeFlags(), d.GetInt64(), upperBound, tp) + val, err = ConvertIntToUint(ctx.Flags(), d.GetInt64(), upperBound, tp) case KindUint64: val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp) case KindFloat32, KindFloat64: - val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetFloat64(), upperBound, tp) + val, err = ConvertFloatToUint(ctx.Flags(), d.GetFloat64(), upperBound, tp) case KindString, KindBytes: var err1 error - val, err1 = StrToUint(sc.TypeCtxOrDefault(), d.GetString(), false) + val, err1 = StrToUint(ctx, d.GetString(), false) val, err = ConvertUintToUint(val, upperBound, tp) if err == nil { err = err1 @@ -1217,7 +1207,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( if err == nil { err = err1 } - val, err1 = ConvertIntToUint(sc.TypeFlags(), ival, upperBound, tp) + val, err1 = ConvertIntToUint(ctx.Flags(), ival, upperBound, tp) if err == nil { err = err1 } @@ -1225,24 +1215,24 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( dec := d.GetMysqlDuration().ToNumber() err = dec.Round(dec, 0, ModeHalfUp) var err1 error - val, err1 = ConvertDecimalToUint(sc, dec, upperBound, tp) + val, err1 = ConvertDecimalToUint(dec, upperBound, tp) if err == nil { err = err1 } case KindMysqlDecimal: - val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp) + val, err = ConvertDecimalToUint(d.GetMysqlDecimal(), upperBound, tp) case KindMysqlEnum: - val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlEnum().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(ctx.Flags(), d.GetMysqlEnum().ToNumber(), upperBound, tp) case KindMysqlSet: - val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlSet().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(ctx.Flags(), d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: - val, err = d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) + val, err = d.GetBinaryLiteral().ToInt(ctx) if err == nil { val, err = ConvertUintToUint(val, upperBound, tp) } case KindMysqlJSON: var i64 int64 - i64, err = ConvertJSONToInt(sc.TypeCtxOrDefault(), d.GetMysqlJSON(), true, tp) + i64, err = ConvertJSONToInt(ctx, d.GetMysqlJSON(), true, tp) val = uint64(i64) default: return invalidConv(d, target.GetType()) @@ -1369,9 +1359,7 @@ func (d *Datum) convertToMysqlTime(ctx Context, target *FieldType) (Datum, error return ret, nil } -func (d *Datum) convertToMysqlDuration(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { - typeCtx := sc.TypeCtx() - +func (d *Datum) convertToMysqlDuration(typeCtx Context, target *FieldType) (Datum, error) { tp := target.GetType() fsp := DefaultFsp if target.GetDecimal() != UnspecifiedLength { @@ -1385,13 +1373,13 @@ func (d *Datum) convertToMysqlDuration(sc *stmtctx.StatementContext, target *Fie ret.SetMysqlDuration(dur) return ret, errors.Trace(err) } - dur, err = dur.RoundFrac(fsp, sc.TimeZone()) + dur, err = dur.RoundFrac(fsp, typeCtx.Location()) ret.SetMysqlDuration(dur) if err != nil { return ret, errors.Trace(err) } case KindMysqlDuration: - dur, err := d.GetMysqlDuration().RoundFrac(fsp, sc.TimeZone()) + dur, err := d.GetMysqlDuration().RoundFrac(fsp, typeCtx.Location()) ret.SetMysqlDuration(dur) if err != nil { return ret, errors.Trace(err) @@ -1402,7 +1390,7 @@ func (d *Datum) convertToMysqlDuration(sc *stmtctx.StatementContext, target *Fie if err != nil { return ret, errors.Trace(err) } - timeNum, err := d.ToInt64(sc) + timeNum, err := d.ToInt64(typeCtx) if err != nil { return ret, errors.Trace(err) } @@ -1540,7 +1528,7 @@ func ProduceDecWithSpecifiedTp(ctx Context, dec *MyDecimal, tp *FieldType) (_ *M } // ConvertToMysqlYear converts a datum to MySQLYear. -func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { +func (d *Datum) ConvertToMysqlYear(ctx Context, target *FieldType) (Datum, error) { var ( ret Datum y int64 @@ -1551,7 +1539,7 @@ func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy case KindString, KindBytes: s := d.GetString() trimS := strings.TrimSpace(s) - y, err = StrToInt(sc.TypeCtxOrDefault(), trimS, false) + y, err = StrToInt(ctx, trimS, false) if err != nil { ret.SetInt64(0) return ret, errors.Trace(err) @@ -1564,13 +1552,13 @@ func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy case KindMysqlTime: y = int64(d.GetMysqlTime().Year()) case KindMysqlJSON: - y, err = ConvertJSONToInt64(sc.TypeCtxOrDefault(), d.GetMysqlJSON(), false) + y, err = ConvertJSONToInt64(ctx, d.GetMysqlJSON(), false) if err != nil { ret.SetInt64(0) return ret, errors.Trace(err) } default: - ret, err = d.convertToInt(sc, NewFieldType(mysql.TypeLonglong)) + ret, err = d.convertToInt(ctx, NewFieldType(mysql.TypeLonglong)) if err != nil { _, err = invalidConv(d, target.GetType()) ret.SetInt64(0) @@ -1592,13 +1580,13 @@ func (d *Datum) convertStringToMysqlBit(ctx Context) (uint64, error) { return bitStr.ToInt(ctx) } -func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { +func (d *Datum) convertToMysqlBit(ctx Context, target *FieldType) (Datum, error) { var ret Datum var uintValue uint64 var err error switch d.k { case KindBytes: - uintValue, err = BinaryLiteral(d.b).ToInt(sc.TypeCtxOrDefault()) + uintValue, err = BinaryLiteral(d.b).ToInt(ctx) case KindString: // For single bit value, we take string like "true", "1" as 1, and "false", "0" as 0, // this behavior is not documented in MySQL, but it behaves so, for more information, see issue #18681 @@ -1610,17 +1598,17 @@ func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldTyp case "false", "0": uintValue = 0 default: - uintValue, err = d.convertStringToMysqlBit(sc.TypeCtxOrDefault()) + uintValue, err = d.convertStringToMysqlBit(ctx) } } else { - uintValue, err = d.convertStringToMysqlBit(sc.TypeCtxOrDefault()) + uintValue, err = d.convertStringToMysqlBit(ctx) } case KindInt64: // if input kind is int64 (signed), when trans to bit, we need to treat it as unsigned d.k = KindUint64 fallthrough default: - uintDatum, err1 := d.convertToUint(sc, target) + uintDatum, err1 := d.convertToUint(ctx, target) uintValue, err = uintDatum.GetUint64(), err1 } // Avoid byte size panic, never goto this branch. @@ -1636,7 +1624,7 @@ func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldTyp return ret, errors.Trace(err) } -func (d *Datum) convertToMysqlEnum(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { +func (d *Datum) convertToMysqlEnum(ctx Context, target *FieldType) (Datum, error) { var ( ret Datum e Enum @@ -1657,7 +1645,7 @@ func (d *Datum) convertToMysqlEnum(sc *stmtctx.StatementContext, target *FieldTy e, err = ParseEnum(target.GetElems(), d.GetMysqlSet().Name, target.GetCollate()) default: var uintDatum Datum - uintDatum, err = d.convertToUint(sc, target) + uintDatum, err = d.convertToUint(ctx, target) if err == nil { e, err = ParseEnumValue(target.GetElems(), uintDatum.GetUint64()) } else { @@ -1668,7 +1656,7 @@ func (d *Datum) convertToMysqlEnum(sc *stmtctx.StatementContext, target *FieldTy return ret, err } -func (d *Datum) convertToMysqlSet(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { +func (d *Datum) convertToMysqlSet(ctx Context, target *FieldType) (Datum, error) { var ( ret Datum s Set @@ -1683,7 +1671,7 @@ func (d *Datum) convertToMysqlSet(sc *stmtctx.StatementContext, target *FieldTyp s, err = ParseSet(target.GetElems(), d.GetMysqlSet().Name, target.GetCollate()) default: var uintDatum Datum - uintDatum, err = d.convertToUint(sc, target) + uintDatum, err = d.convertToUint(ctx, target) if err == nil { s, err = ParseSetValue(target.GetElems(), uintDatum.GetUint64()) } @@ -1695,7 +1683,7 @@ func (d *Datum) convertToMysqlSet(sc *stmtctx.StatementContext, target *FieldTyp return ret, err } -func (d *Datum) convertToMysqlJSON(_ *stmtctx.StatementContext, _ *FieldType) (ret Datum, err error) { +func (d *Datum) convertToMysqlJSON(_ *FieldType) (ret Datum, err error) { switch d.k { case KindString, KindBytes: var j BinaryJSON @@ -1845,15 +1833,15 @@ func (d *Datum) ToDecimal(ctx Context) (*MyDecimal, error) { } // ToInt64 converts to a int64. -func (d *Datum) ToInt64(sc *stmtctx.StatementContext) (int64, error) { +func (d *Datum) ToInt64(ctx Context) (int64, error) { if d.Kind() == KindMysqlBit { - uintVal, err := d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) + uintVal, err := d.GetBinaryLiteral().ToInt(ctx) return int64(uintVal), err } - return d.toSignedInteger(sc, mysql.TypeLonglong) + return d.toSignedInteger(ctx, mysql.TypeLonglong) } -func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, error) { +func (d *Datum) toSignedInteger(ctx Context, tp byte) (int64, error) { lowerBound := IntergerSignedLowerBound(tp) upperBound := IntergerSignedUpperBound(tp) switch d.Kind() { @@ -1866,7 +1854,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e case KindFloat64: return ConvertFloatToInt(d.GetFloat64(), lowerBound, upperBound, tp) case KindString, KindBytes: - iVal, err := StrToInt(sc.TypeCtxOrDefault(), d.GetString(), false) + iVal, err := StrToInt(ctx, d.GetString(), false) iVal, err2 := ConvertIntToInt(iVal, lowerBound, upperBound, tp) if err == nil { err = err2 @@ -1875,7 +1863,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e case KindMysqlTime: // 2011-11-10 11:11:11.999999 -> 20111110111112 // 2011-11-10 11:59:59.999999 -> 20111110120000 - t, err := d.GetMysqlTime().RoundFrac(sc.TypeCtxOrDefault(), DefaultFsp) + t, err := d.GetMysqlTime().RoundFrac(ctx, DefaultFsp) if err != nil { return 0, errors.Trace(err) } @@ -1888,7 +1876,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e case KindMysqlDuration: // 11:11:11.999999 -> 111112 // 11:59:59.999999 -> 120000 - dur, err := d.GetMysqlDuration().RoundFrac(DefaultFsp, sc.TimeZone()) + dur, err := d.GetMysqlDuration().RoundFrac(DefaultFsp, ctx.Location()) if err != nil { return 0, errors.Trace(err) } @@ -1917,9 +1905,9 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e fval := d.GetMysqlSet().ToNumber() return ConvertFloatToInt(fval, lowerBound, upperBound, tp) case KindMysqlJSON: - return ConvertJSONToInt(sc.TypeCtxOrDefault(), d.GetMysqlJSON(), false, tp) + return ConvertJSONToInt(ctx, d.GetMysqlJSON(), false, tp) case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault()) + val, err := d.GetBinaryLiteral().ToInt(ctx) if err != nil { return 0, errors.Trace(err) } @@ -2263,15 +2251,15 @@ func MaxValueDatum() Datum { } // SortDatums sorts a slice of datum. -func SortDatums(sc *stmtctx.StatementContext, datums []Datum) error { - sorter := datumsSorter{datums: datums, sc: sc} +func SortDatums(ctx Context, datums []Datum) error { + sorter := datumsSorter{datums: datums, ctx: ctx} sort.Sort(&sorter) return sorter.err } type datumsSorter struct { datums []Datum - sc *stmtctx.StatementContext + ctx Context err error } @@ -2280,7 +2268,7 @@ func (ds *datumsSorter) Len() int { } func (ds *datumsSorter) Less(i, j int) bool { - cmp, err := ds.datums[i].Compare(ds.sc, &ds.datums[j], collate.GetCollator(ds.datums[i].Collation())) + cmp, err := ds.datums[i].Compare(ds.ctx, &ds.datums[j], collate.GetCollator(ds.datums[i].Collation())) if err != nil { ds.err = errors.Trace(err) return true @@ -2461,11 +2449,11 @@ func getDatumBound(retType *FieldType, rType RoundingType) Datum { // case, we should judge whether the rounding type are ceiling. If it is, then we should plus one for // 1.0 and get the reverse result 2.0. func ChangeReverseResultByUpperLowerBound( - sc *stmtctx.StatementContext, + ctx Context, retType *FieldType, res Datum, rType RoundingType) (Datum, error) { - d, err := res.ConvertTo(sc, retType) + d, err := res.ConvertTo(ctx, retType) if terror.ErrorEqual(err, ErrOverflow) { return d, nil } @@ -2489,7 +2477,7 @@ func ChangeReverseResultByUpperLowerBound( resRetType.SetDecimalUnderLimit(int(res.GetMysqlDecimal().GetDigitsInt())) } bound := getDatumBound(&resRetType, rType) - cmp, err := d.Compare(sc, &bound, collate.GetCollator(resRetType.GetCollate())) + cmp, err := d.Compare(ctx, &bound, collate.GetCollator(resRetType.GetCollate())) if err != nil { return d, err } diff --git a/pkg/types/datum_test.go b/pkg/types/datum_test.go index dc931779fd5a9..b1f4dc053e0eb 100644 --- a/pkg/types/datum_test.go +++ b/pkg/types/datum_test.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/hack" "github.com/stretchr/testify/assert" @@ -115,10 +114,9 @@ func TestToBool(t *testing.T) { func testDatumToInt64(t *testing.T, val interface{}, expect int64) { d := NewDatum(val) - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) - b, err := d.ToInt64(sc) + b, err := d.ToInt64(ctx) require.NoError(t, err) require.Equal(t, expect, b) } @@ -152,12 +150,11 @@ func TestToInt64(t *testing.T) { func testDatumToUInt32(t *testing.T, val interface{}, expect uint32, hasError bool) { d := NewDatum(val) - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) ft := NewFieldType(mysql.TypeLong) ft.AddFlag(mysql.UnsignedFlag) - converted, err := d.ConvertTo(sc, ft) + converted, err := d.ConvertTo(ctx, ft) if hasError { require.Error(t, err) @@ -204,10 +201,9 @@ func TestConvertToFloat(t *testing.T) { {NewDatum("281.37"), mysql.TypeFloat, "", 281.37, 281.37}, } - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) for _, testCase := range testCases { - converted, err := testCase.d.ConvertTo(sc, NewFieldType(testCase.tp)) + converted, err := testCase.d.ConvertTo(ctx, NewFieldType(testCase.tp)) if testCase.errMsg == "" { require.NoError(t, err) } else { @@ -241,7 +237,6 @@ func mustParseTimeIntoDatum(s string, tp byte, fsp int) (d Datum) { func TestToJSON(t *testing.T) { ft := NewFieldType(mysql.TypeJSON) - sc := stmtctx.NewStmtCtx() tests := []struct { datum Datum expected interface{} @@ -260,14 +255,14 @@ func TestToJSON(t *testing.T) { {NewStringDatum("hello, 世界"), "", false}, } for _, tt := range tests { - obtain, err := tt.datum.ConvertTo(sc, ft) + obtain, err := tt.datum.ConvertTo(DefaultStmtNoWarningContext, ft) if tt.success { require.NoError(t, err) expected := NewJSONDatum(CreateBinaryJSON(tt.expected)) var cmp int - cmp, err = obtain.Compare(sc, &expected, collate.GetBinaryCollator()) + cmp, err = obtain.Compare(DefaultStmtNoWarningContext, &expected, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) } else { @@ -309,8 +304,6 @@ func TestToBytes(t *testing.T) { {NewStringDatum("abc"), []byte("abc")}, {Datum{}, []byte{}}, } - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) for _, tt := range tests { bin, err := tt.a.ToBytes() require.NoError(t, err) @@ -319,7 +312,6 @@ func TestToBytes(t *testing.T) { } func TestComputePlusAndMinus(t *testing.T) { - sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) tests := []struct { a Datum b Datum @@ -340,7 +332,7 @@ func TestComputePlusAndMinus(t *testing.T) { for ith, tt := range tests { got, err := ComputePlus(tt.a, tt.b) require.Equal(t, tt.hasErr, err != nil) - v, err := got.Compare(sc, &tt.plus, collate.GetBinaryCollator()) + v, err := got.Compare(DefaultStmtNoWarningContext, &tt.plus, collate.GetBinaryCollator()) require.NoError(t, err) require.Equalf(t, 0, v, "%dth got:%#v, %#v, expect:%#v, %#v", ith, got, got.x, tt.plus, tt.plus.x) } @@ -358,11 +350,11 @@ func TestCloneDatum(t *testing.T) { raw, } - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) + for _, tt := range tests { tt1 := *tt.Clone() - res, err := tt.Compare(sc, &tt1, collate.GetBinaryCollator()) + res, err := tt.Compare(ctx, &tt1, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, res) if tt.b != nil { @@ -412,9 +404,7 @@ func TestEstimatedMemUsage(t *testing.T) { } func TestChangeReverseResultByUpperLowerBound(t *testing.T) { - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) - sc.OverflowAsWarning = true + ctx := DefaultStmtNoWarningContext.WithFlags(DefaultStmtFlags.WithIgnoreTruncateErr(true)) // TODO: add more reserve convert tests for each pair of convert type. testData := []struct { a Datum @@ -499,10 +489,10 @@ func TestChangeReverseResultByUpperLowerBound(t *testing.T) { }, } for ith, test := range testData { - reverseRes, err := ChangeReverseResultByUpperLowerBound(sc, test.retType, test.a, test.roundType) + reverseRes, err := ChangeReverseResultByUpperLowerBound(ctx, test.retType, test.a, test.roundType) require.NoError(t, err) var cmp int - cmp, err = reverseRes.Compare(sc, &test.res, collate.GetBinaryCollator()) + cmp, err = reverseRes.Compare(ctx, &test.res, collate.GetBinaryCollator()) require.NoError(t, err) require.Equalf(t, 0, cmp, "%dth got:%#v, expect:%#v", ith, reverseRes, test.res) } @@ -537,12 +527,10 @@ func TestStringToMysqlBit(t *testing.T) { {NewStringDatum("b'1'"), []byte{1}}, {NewStringDatum("b'0'"), []byte{0}}, } - sc := stmtctx.NewStmtCtx() - sc.SetTypeFlags(sc.TypeFlags().WithIgnoreTruncateErr(true)) tp := NewFieldType(mysql.TypeBit) tp.SetFlen(1) for _, tt := range tests { - bin, err := tt.a.convertToMysqlBit(nil, tp) + bin, err := tt.a.convertToMysqlBit(DefaultStmtNoWarningContext, tp) require.NoError(t, err) require.Equal(t, tt.out, bin.b) } @@ -603,11 +591,10 @@ func TestMarshalDatum(t *testing.T) { func BenchmarkCompareDatum(b *testing.B) { vals, vals1 := prepareCompareDatums() - sc := stmtctx.NewStmtCtx() b.ResetTimer() for i := 0; i < b.N; i++ { for j, v := range vals { - _, err := v.Compare(sc, &vals1[j], collate.GetBinaryCollator()) + _, err := v.Compare(DefaultStmtNoWarningContext, &vals1[j], collate.GetBinaryCollator()) if err != nil { b.Fatal(err) } @@ -650,11 +637,12 @@ func TestProduceDecWithSpecifiedTp(t *testing.T) { {"99.9999", 6, 3, "100.000", false, true}, {"-99.9999", 6, 3, "-100.000", false, true}, } - sc := stmtctx.NewStmtCtx() + warnings := &warnStore{} + ctx := NewContext(DefaultStmtFlags, time.UTC, warnings.AppendWarning) for _, tt := range tests { tp := NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(tt.flen).SetDecimal(tt.frac).BuildP() dec := NewDecFromStringForTest(tt.dec) - newDec, err := ProduceDecWithSpecifiedTp(sc.TypeCtx(), dec, tp) + newDec, err := ProduceDecWithSpecifiedTp(ctx, dec, tp) if tt.isOverflow { if !ErrOverflow.Equal(err) { assert.FailNow(t, "Error is not overflow", "err: %v before: %v after: %v", err, tt.dec, dec) @@ -663,9 +651,10 @@ func TestProduceDecWithSpecifiedTp(t *testing.T) { require.NoError(t, err, tt) } require.Equal(t, tt.newDec, newDec.String()) - warn := sc.TruncateWarnings(0) + warn := warnings.GetWarnings() + warnings.Reset() if tt.isTruncated { - if len(warn) != 1 || !ErrTruncatedWrongVal.Equal(warn[0].Err) { + if len(warn) != 1 || !ErrTruncatedWrongVal.Equal(warn[0]) { assert.FailNow(t, "Warn is not truncated", "warn: %v before: %v after: %v", warn, tt.dec, dec) } } else { @@ -696,9 +685,8 @@ func TestNULLNotEqualWithOthers(t *testing.T) { MaxValueDatum(), } nullDatum := NewDatum(nil) - sc := stmtctx.NewStmtCtx() for _, d := range datums { - result, err := d.Compare(sc, &nullDatum, collate.GetBinaryCollator()) + result, err := d.Compare(DefaultStmtNoWarningContext, &nullDatum, collate.GetBinaryCollator()) require.NoError(t, err) require.NotEqual(t, 0, result) } diff --git a/pkg/types/time.go b/pkg/types/time.go index 3cc91576a8e61..cadb6968cfa07 100644 --- a/pkg/types/time.go +++ b/pkg/types/time.go @@ -30,7 +30,6 @@ import ( "github.com/pingcap/tidb/pkg/errno" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/util/dbterror" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/mathutil" @@ -2023,7 +2022,7 @@ func ParseDate(ctx Context, str string) (Time, error) { // ParseTimeFromYear parse a `YYYY` formed year to corresponded Datetime type. // Note: the invoker must promise the `year` is in the range [MinYear, MaxYear]. -func ParseTimeFromYear(_ *stmtctx.StatementContext, year int64) (Time, error) { +func ParseTimeFromYear(year int64) (Time, error) { if year == 0 { return NewTime(ZeroCoreTime, mysql.TypeDate, DefaultFsp), nil } diff --git a/pkg/types/context/truncate.go b/pkg/types/truncate.go similarity index 99% rename from pkg/types/context/truncate.go rename to pkg/types/truncate.go index 271c8ed4b1d16..67b531925b480 100644 --- a/pkg/types/context/truncate.go +++ b/pkg/types/truncate.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package context +package types import ( "github.com/pingcap/errors" diff --git a/pkg/util/chunk/chunk_test.go b/pkg/util/chunk/chunk_test.go index 4bbb3bc723ed0..7ff2292db7f89 100644 --- a/pkg/util/chunk/chunk_test.go +++ b/pkg/util/chunk/chunk_test.go @@ -545,7 +545,7 @@ func TestGetDecimalDatum(t *testing.T) { decType.SetFlen(4) decType.SetDecimal(2) sc := stmtctx.NewStmtCtx() - decDatum, err := datum.ConvertTo(sc, decType) + decDatum, err := datum.ConvertTo(sc.TypeCtx(), decType) require.NoError(t, err) chk := NewChunkWithCapacity([]*types.FieldType{decType}, 32) diff --git a/pkg/util/chunk/mutrow_test.go b/pkg/util/chunk/mutrow_test.go index b09663f0084d8..3e3a2c902326c 100644 --- a/pkg/util/chunk/mutrow_test.go +++ b/pkg/util/chunk/mutrow_test.go @@ -34,7 +34,7 @@ func TestMutRow(t *testing.T) { val := zeroValForType(allTypes[i]) d := row.GetDatum(i, allTypes[i]) d2 := types.NewDatum(val) - cmp, err := d.Compare(sc, &d2, collate.GetCollator(allTypes[i].GetCollate())) + cmp, err := d.Compare(sc.TypeCtx(), &d2, collate.GetCollator(allTypes[i].GetCollate())) require.NoError(t, err) require.Equal(t, 0, cmp) } diff --git a/pkg/util/codec/codec_test.go b/pkg/util/codec/codec_test.go index 4b636c3087e56..b223938662c23 100644 --- a/pkg/util/codec/codec_test.go +++ b/pkg/util/codec/codec_test.go @@ -963,7 +963,7 @@ func TestDecodeOneToChunk(t *testing.T) { require.True(t, expect.IsNull()) } else { if got.Kind() != types.KindMysqlDecimal { - cmp, err := got.Compare(sc, &expect, collate.GetCollator(tp.GetCollate())) + cmp, err := got.Compare(sc.TypeCtx(), &expect, collate.GetCollator(tp.GetCollate())) require.NoError(t, err) require.Equalf(t, 0, cmp, "expect: %v, got %v", expect, got) } else { @@ -1090,7 +1090,7 @@ func TestDecodeRange(t *testing.T) { datums1, _, err := DecodeRange(rowData, len(datums), nil, nil) require.NoError(t, err) for i, datum := range datums1 { - cmp, err := datum.Compare(nil, &datums[i], collate.GetBinaryCollator()) + cmp, err := datum.Compare(types.DefaultStmtNoWarningContext, &datums[i], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) } diff --git a/pkg/util/profile/flamegraph_test.go b/pkg/util/profile/flamegraph_test.go index 02cb46f8859a5..71ad775b75fda 100644 --- a/pkg/util/profile/flamegraph_test.go +++ b/pkg/util/profile/flamegraph_test.go @@ -92,7 +92,7 @@ func TestProfileToDatum(t *testing.T) { comment = fmt.Sprintf("row %2d, actual (%s), expected (%s)", i, rowStr, expectStr) equal := true for j, r := range row { - v, err := r.Compare(nil, &datums[i][j], collate.GetBinaryCollator()) + v, err := r.Compare(types.DefaultStmtNoWarningContext, &datums[i][j], collate.GetBinaryCollator()) if v != 0 || err != nil { equal = false break diff --git a/pkg/util/ranger/detacher.go b/pkg/util/ranger/detacher.go index 071f78a8925c3..b9fe83d0d5d11 100644 --- a/pkg/util/ranger/detacher.go +++ b/pkg/util/ranger/detacher.go @@ -544,7 +544,7 @@ func allSinglePoints(sc *stmtctx.StatementContext, points []*point) []*point { return nil } // Since the point's collations are equal to the column's collation, we can use any of them. - cmp, err := left.value.Compare(sc, &right.value, collate.GetCollator(left.value.Collation())) + cmp, err := left.value.Compare(sc.TypeCtx(), &right.value, collate.GetCollator(left.value.Collation())) if err != nil || cmp != 0 { return nil } @@ -831,7 +831,7 @@ func isSameValue(sc *stmtctx.StatementContext, lhs, rhs *valueInfo) (bool, error return false, nil } // binary collator may not the best choice, but it can make sure the result is correct. - cmp, err := lhs.value.Compare(sc, rhs.value, collate.GetBinaryCollator()) + cmp, err := lhs.value.Compare(sc.TypeCtx(), rhs.value, collate.GetBinaryCollator()) if err != nil { return false, err } diff --git a/pkg/util/ranger/points.go b/pkg/util/ranger/points.go index 6830b953dbf0c..487c885c749b4 100644 --- a/pkg/util/ranger/points.go +++ b/pkg/util/ranger/points.go @@ -108,7 +108,7 @@ func rangePointLess(sc *stmtctx.StatementContext, a, b *point, collator collate. if a.value.Kind() == types.KindMysqlEnum && b.value.Kind() == types.KindMysqlEnum { return rangePointEnumLess(sc, a, b) } - cmp, err := a.value.Compare(sc, &b.value, collator) + cmp, err := a.value.Compare(sc.TypeCtx(), &b.value, collator) if cmp != 0 { return cmp < 0, nil } @@ -254,11 +254,11 @@ func (r *builder) buildFromBinOp(expr *expression.ScalarFunction) []*point { // If the original value is adjusted, we need to change the condition. // For example, col < 2156. Since the max year is 2155, 2156 is changed to 2155. // col < 2155 is wrong. It should be col <= 2155. - preValue, err1 := value.ToInt64(r.sc) + preValue, err1 := value.ToInt64(r.sc.TypeCtx()) if err1 != nil { return err1 } - *value, err = value.ConvertToMysqlYear(r.sc, col.RetType) + *value, err = value.ConvertToMysqlYear(r.sc.TypeCtx(), col.RetType) if errors.ErrorEqual(err, types.ErrWarnDataOutOfRange) { // Keep err for EQ and NE. switch *op { @@ -473,7 +473,7 @@ func handleEnumFromBinOp(sc *stmtctx.StatementContext, ft *types.FieldType, val } d := types.NewCollateMysqlEnumDatum(tmpEnum, ft.GetCollate()) - if v, err := d.Compare(sc, &val, collate.GetCollator(ft.GetCollate())); err == nil { + if v, err := d.Compare(sc.TypeCtx(), &val, collate.GetCollator(ft.GetCollate())); err == nil { switch op { case ast.LT: if v < 0 { @@ -585,7 +585,7 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool) err = parseErr } default: - dt, err = dt.ConvertTo(r.sc, expr.GetArgs()[0].GetType()) + dt, err = dt.ConvertTo(r.sc.TypeCtx(), expr.GetArgs()[0].GetType()) } if err != nil { @@ -594,7 +594,7 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool) } } if expr.GetArgs()[0].GetType().GetType() == mysql.TypeYear { - dt, err = dt.ConvertToMysqlYear(r.sc, expr.GetArgs()[0].GetType()) + dt, err = dt.ConvertToMysqlYear(r.sc.TypeCtx(), expr.GetArgs()[0].GetType()) if err != nil { // in (..., an impossible value (not valid year), ...), the range is empty, so skip it. continue diff --git a/pkg/util/ranger/ranger.go b/pkg/util/ranger/ranger.go index 10fa1cc70f6ae..efe801ef553cb 100644 --- a/pkg/util/ranger/ranger.go +++ b/pkg/util/ranger/ranger.go @@ -158,7 +158,7 @@ func convertPoint(sctx sessionctx.Context, point *point, tp *types.FieldType) (* case types.KindMaxValue, types.KindMinNotNull: return point, nil } - casted, err := point.value.ConvertTo(sc, tp) + casted, err := point.value.ConvertTo(sc.TypeCtx(), tp) if err != nil { if sctx.GetSessionVars().StmtCtx.InPreparedPlanBuilding { // skip plan cache in this case for safety. @@ -196,7 +196,7 @@ func convertPoint(sctx sessionctx.Context, point *point, tp *types.FieldType) (* } //revive:enable:empty-block } - valCmpCasted, err := point.value.Compare(sc, &casted, collate.GetCollator(tp.GetCollate())) + valCmpCasted, err := point.value.Compare(sc.TypeCtx(), &casted, collate.GetCollator(tp.GetCollate())) if err != nil { return point, errors.Trace(err) } @@ -772,7 +772,7 @@ func RangesToString(sc *stmtctx.StatementContext, rans Ranges, colNames []string // sanity check: only last column of the `Range` can be an interval if j < len(ran.LowVal)-1 { - cmp, err := ran.LowVal[j].Compare(sc, &ran.HighVal[j], ran.Collators[j]) + cmp, err := ran.LowVal[j].Compare(sc.TypeCtx(), &ran.HighVal[j], ran.Collators[j]) if err != nil { return "", errors.New("comparing values error: " + err.Error()) } @@ -829,7 +829,7 @@ func RangeSingleColToString(sc *stmtctx.StatementContext, lowVal, highVal types. restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &buf) // case 2: low value and high value are the same, and low value and high value are both inclusive. - cmp, err := lowVal.Compare(sc, &highVal, collator) + cmp, err := lowVal.Compare(sc.TypeCtx(), &highVal, collator) if err != nil { return "false", errors.Trace(err) } diff --git a/pkg/util/ranger/types.go b/pkg/util/ranger/types.go index c3c8cbaa9ad02..f2b4c22787c3e 100644 --- a/pkg/util/ranger/types.go +++ b/pkg/util/ranger/types.go @@ -109,7 +109,7 @@ func (ran *Range) isPoint(stmtCtx *stmtctx.StatementContext, regardNullAsPoint b if a.Kind() == types.KindMinNotNull || b.Kind() == types.KindMaxValue { return false } - cmp, err := a.Compare(stmtCtx, &b, ran.Collators[i]) + cmp, err := a.Compare(stmtCtx.TypeCtx(), &b, ran.Collators[i]) if err != nil { return false } @@ -217,7 +217,7 @@ func (ran *Range) Encode(sc *stmtctx.StatementContext, lowBuffer, highBuffer []b func (ran *Range) PrefixEqualLen(sc *stmtctx.StatementContext) (int, error) { // Here, len(ran.LowVal) always equal to len(ran.HighVal) for i := 0; i < len(ran.LowVal); i++ { - cmp, err := ran.LowVal[i].Compare(sc, &ran.HighVal[i], ran.Collators[i]) + cmp, err := ran.LowVal[i].Compare(sc.TypeCtx(), &ran.HighVal[i], ran.Collators[i]) if err != nil { return 0, errors.Trace(err) } diff --git a/pkg/util/rowDecoder/decoder_test.go b/pkg/util/rowDecoder/decoder_test.go index 039d929f7359c..78eb540bedbf5 100644 --- a/pkg/util/rowDecoder/decoder_test.go +++ b/pkg/util/rowDecoder/decoder_test.go @@ -125,7 +125,7 @@ func TestRowDecoder(t *testing.T) { for i, col := range cols[:len(cols)-1] { v, ok := r[col.ID] if ok { - equal, err1 := v.Compare(sc, &row.output[i], collate.GetBinaryCollator()) + equal, err1 := v.Compare(sc.TypeCtx(), &row.output[i], collate.GetBinaryCollator()) require.Nil(t, err1) require.Equal(t, 0, equal) } else { @@ -139,7 +139,7 @@ func TestRowDecoder(t *testing.T) { for k, v := range r2 { v1, ok := r[k] require.True(t, ok) - equal, err1 := v.Compare(sc, &v1, collate.GetBinaryCollator()) + equal, err1 := v.Compare(sc.TypeCtx(), &v1, collate.GetBinaryCollator()) require.Nil(t, err1) require.Equal(t, 0, equal) } @@ -197,7 +197,7 @@ func TestClusterIndexRowDecoder(t *testing.T) { for i, col := range cols { v, ok := r[col.ID] require.True(t, ok) - equal, err1 := v.Compare(sc, &row.output[i], collate.GetBinaryCollator()) + equal, err1 := v.Compare(sc.TypeCtx(), &row.output[i], collate.GetBinaryCollator()) require.Nil(t, err1) require.Equal(t, 0, equal) }