diff --git a/br/pkg/lightning/backend/kv/session.go b/br/pkg/lightning/backend/kv/session.go index 6d52a7a4788e3..d6d25625de052 100644 --- a/br/pkg/lightning/backend/kv/session.go +++ b/br/pkg/lightning/backend/kv/session.go @@ -272,6 +272,11 @@ type planCtxImpl struct { *planctximpl.PlanCtxExtendedImpl } +type exprCtxImpl struct { + *Session + *exprctximpl.ExprCtxExtendedImpl +} + // Session is a trimmed down Session type which only wraps our own trimmed-down // transaction type and provides the session variables to the TiDB library // optimized for Lightning. @@ -280,6 +285,7 @@ type Session struct { planctx.EmptyPlanContextExtended txn transaction Vars *variable.SessionVars + exprCtx *exprCtxImpl planctx *planCtxImpl tblctx *tbctximpl.TableContextImpl // currently, we only set `CommonAddRecordCtx` @@ -342,11 +348,15 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { } vars.TxnCtx = nil s.Vars = vars + s.exprCtx = &exprCtxImpl{ + Session: s, + ExprCtxExtendedImpl: exprctximpl.NewExprExtendedImpl(s), + } s.planctx = &planCtxImpl{ Session: s, - PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s, exprctximpl.NewExprExtendedImpl(s)), + PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s), } - s.tblctx = tbctximpl.NewTableContextImpl(s, s.planctx) + s.tblctx = tbctximpl.NewTableContextImpl(s, s.exprCtx) s.txn.kvPairs = &Pairs{} return s @@ -381,7 +391,7 @@ func (se *Session) GetPlanCtx() planctx.PlanContext { // GetExprCtx returns the expression context of the session. func (se *Session) GetExprCtx() exprctx.BuildContext { - return se.planctx + return se.exprCtx } // GetTableCtx returns the table.MutateContext diff --git a/pkg/executor/aggfuncs/func_group_concat_test.go b/pkg/executor/aggfuncs/func_group_concat_test.go index f2d472d82508e..af945cd22dfed 100644 --- a/pkg/executor/aggfuncs/func_group_concat_test.go +++ b/pkg/executor/aggfuncs/func_group_concat_test.go @@ -129,7 +129,7 @@ func groupConcatOrderMultiArgsUpdateMemDeltaGens(ctx sessionctx.Context, srcChk } memDelta := int64(buffer.Cap() - oldMemSize) for _, byItem := range byItems { - fdt, _ := byItem.Expr.Eval(ctx.GetPlanCtx(), row) + fdt, _ := byItem.Expr.Eval(ctx.GetExprCtx(), row) datumMem := aggfuncs.GetDatumMemSize(&fdt) memDelta += datumMem } diff --git a/pkg/executor/merge_join.go b/pkg/executor/merge_join.go index fc7e2a3f0cc52..d874eb02d5420 100644 --- a/pkg/executor/merge_join.go +++ b/pkg/executor/merge_join.go @@ -401,7 +401,7 @@ func (e *MergeJoinExec) compare(outerRow, innerRow chunk.Row) (int, error) { outerJoinKeys := e.outerTable.joinKeys innerJoinKeys := e.innerTable.joinKeys for i := range outerJoinKeys { - cmp, _, err := e.compareFuncs[i](e.Ctx().GetPlanCtx(), outerJoinKeys[i], innerJoinKeys[i], outerRow, innerRow) + cmp, _, err := e.compareFuncs[i](e.Ctx().GetExprCtx(), outerJoinKeys[i], innerJoinKeys[i], outerRow, innerRow) if err != nil { return 0, err } diff --git a/pkg/planner/cardinality/selectivity.go b/pkg/planner/cardinality/selectivity.go index 1bb595b233daa..a811d97191d45 100644 --- a/pkg/planner/cardinality/selectivity.go +++ b/pkg/planner/cardinality/selectivity.go @@ -81,7 +81,7 @@ func Selectivity( ret = pseudoSelectivity(coll, exprs) if sc.EnableOptimizerCETrace { ceTraceExpr(ctx, tableID, "Table Stats-Pseudo-Expression", - expression.ComposeCNFCondition(ctx, exprs...), ret*float64(coll.RealtimeCount)) + expression.ComposeCNFCondition(ctx.GetExprCtx(), exprs...), ret*float64(coll.RealtimeCount)) } return ret, nil, nil } @@ -229,7 +229,7 @@ func Selectivity( curExpr = append(curExpr, remainedExprs[i]) } } - expr := expression.ComposeCNFCondition(ctx, curExpr...) + expr := expression.ComposeCNFCondition(ctx.GetExprCtx(), curExpr...) ceTraceExpr(ctx, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.RealtimeCount)) } else if sc.EnableOptimizerDebugTrace { var strs []string @@ -286,7 +286,7 @@ func Selectivity( // Try to cover remaining Constants for i, c := range notCoveredConstants { - if expression.MaybeOverOptimized4PlanCache(ctx, []expression.Expression{c}) { + if expression.MaybeOverOptimized4PlanCache(ctx.GetExprCtx(), []expression.Expression{c}) { continue } if c.Value.IsNull() { @@ -366,7 +366,7 @@ OUTER: if sc.EnableOptimizerCETrace { // Tracing for the expression estimation results after applying the DNF estimation result. curExpr = append(curExpr, remainedExprs[i]) - expr := expression.ComposeCNFCondition(ctx, curExpr...) + expr := expression.ComposeCNFCondition(ctx.GetExprCtx(), curExpr...) ceTraceExpr(ctx, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.RealtimeCount)) } } @@ -428,7 +428,7 @@ OUTER: if sc.EnableOptimizerCETrace { // Tracing for the expression estimation results after applying the default selectivity. - totalExpr := expression.ComposeCNFCondition(ctx, remainedExprs...) + totalExpr := expression.ComposeCNFCondition(ctx.GetExprCtx(), remainedExprs...) ceTraceExpr(ctx, tableID, "Table Stats-Expression-CNF", totalExpr, ret*float64(coll.RealtimeCount)) } return ret, nodes, nil @@ -668,7 +668,7 @@ func findPrefixOfIndexByCol(ctx context.PlanContext, cols []*expression.Column, idLoop: for _, idCol := range idxCols { for _, col := range cols { - if col.EqualByExprAndID(ctx, idCol) { + if col.EqualByExprAndID(ctx.GetExprCtx(), idCol) { retCols = append(retCols, col) continue idLoop } @@ -715,7 +715,7 @@ func getMaskAndRanges(ctx context.PlanContext, exprs []expression.Expression, ra } for i := range exprs { for j := range accessConds { - if exprs[i].Equal(ctx, accessConds[j]) { + if exprs[i].Equal(ctx.GetExprCtx(), accessConds[j]) { mask |= 1 << uint64(i) break } @@ -745,7 +745,7 @@ func getMaskAndSelectivityForMVIndex( var mask int64 for i := range exprs { for _, accessCond := range accessConds { - if exprs[i].Equal(ctx, accessCond) { + if exprs[i].Equal(ctx.GetExprCtx(), accessCond) { mask |= 1 << uint64(i) break } @@ -841,7 +841,7 @@ func GetSelectivityByFilter(sctx context.PlanContext, coll *statistics.HistColl, } c.AppendDatum(0, &val) } - selected, err = expression.VectorizedFilter(sctx, filters, chunk.NewIterator4Chunk(c), selected) + selected, err = expression.VectorizedFilter(sctx.GetExprCtx(), filters, chunk.NewIterator4Chunk(c), selected) if err != nil { return false, 0, err } @@ -858,7 +858,7 @@ func GetSelectivityByFilter(sctx context.PlanContext, coll *statistics.HistColl, // The buckets lower bounds are used as random samples and are regarded equally. if hist != nil && histTotalCnt > 0 { selected = selected[:0] - selected, err = expression.VectorizedFilter(sctx, filters, chunk.NewIterator4Chunk(hist.Bounds), selected) + selected, err = expression.VectorizedFilter(sctx.GetExprCtx(), filters, chunk.NewIterator4Chunk(hist.Bounds), selected) if err != nil { return false, 0, err } @@ -892,7 +892,7 @@ func GetSelectivityByFilter(sctx context.PlanContext, coll *statistics.HistColl, c.Reset() c.AppendNull(0) selected = selected[:0] - selected, err = expression.VectorizedFilter(sctx, filters, chunk.NewIterator4Chunk(c), selected) + selected, err = expression.VectorizedFilter(sctx.GetExprCtx(), filters, chunk.NewIterator4Chunk(c), selected) if err != nil || len(selected) != 1 || !selected[0] { nullSel = 0 } else { diff --git a/pkg/planner/cardinality/trace.go b/pkg/planner/cardinality/trace.go index a067adbf1e34b..07811f3c5acf3 100644 --- a/pkg/planner/cardinality/trace.go +++ b/pkg/planner/cardinality/trace.go @@ -37,7 +37,7 @@ import ( // ceTraceExpr appends an expression and related information into CE trace func ceTraceExpr(sctx context.PlanContext, tableID int64, tp string, expr expression.Expression, rowCount float64) { - exprStr, err := exprToString(sctx, expr) + exprStr, err := exprToString(sctx.GetExprCtx(), expr) if err != nil { logutil.BgLogger().Debug("Failed to trace CE of an expression", zap.String("category", "OptimizerTrace"), zap.Any("expression", expr)) @@ -64,7 +64,7 @@ func ceTraceExpr(sctx context.PlanContext, tableID int64, tp string, expr expres // It may be more appropriate to put this in expression package. But currently we only use it for CE trace, // // and it may not be general enough to handle all possible expressions. So we put it here for now. -func exprToString(ctx context.PlanContext, e expression.Expression) (string, error) { +func exprToString(ctx expression.EvalContext, e expression.Expression) (string, error) { switch expr := e.(type) { case *expression.ScalarFunction: var buffer bytes.Buffer diff --git a/pkg/planner/cascades/transformation_rules.go b/pkg/planner/cascades/transformation_rules.go index 9977527f3f675..d0fbbad5ff7a5 100644 --- a/pkg/planner/cascades/transformation_rules.go +++ b/pkg/planner/cascades/transformation_rules.go @@ -268,7 +268,7 @@ func (*PushSelDownIndexScan) OnTransform(old *memo.ExprIter) (newExprs []*memo.G // or the pushed down conditions are the same with before. sameConds := true for i := range res.AccessConds { - if !res.AccessConds[i].Equal(is.SCtx(), is.AccessConds[i]) { + if !res.AccessConds[i].Equal(is.SCtx().GetExprCtx(), is.AccessConds[i]) { sameConds = false break } @@ -331,7 +331,7 @@ func (*PushSelDownTiKVSingleGather) OnTransform(old *memo.ExprIter) (newExprs [] childGroup := old.Children[0].Children[0].Group var pushed, remained []expression.Expression sctx := sg.SCtx() - pushed, remained = expression.PushDownExprs(sctx, sel.Conditions, sctx.GetClient(), kv.TiKV) + pushed, remained = expression.PushDownExprs(sctx.GetExprCtx(), sel.Conditions, sctx.GetClient(), kv.TiKV) if len(pushed) == 0 { return nil, false, false, nil } @@ -551,7 +551,7 @@ func (*PushSelDownProjection) OnTransform(old *memo.ExprIter) (newExprs []*memo. canNotBePushed := make([]expression.Expression, 0, len(sel.Conditions)) ctx := sel.SCtx() for _, cond := range sel.Conditions { - substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(ctx, cond, projSchema, proj.Exprs, true) + substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(ctx.GetExprCtx(), cond, projSchema, proj.Exprs, true) if substituted && !hasFailed && !expression.HasGetSetVarFunc(newFilter) { canBePushed = append(canBePushed, newFilter) } else { @@ -870,8 +870,8 @@ func (*pushDownJoin) predicatePushDown( tempCond = append(tempCond, expression.ScalarFuncs2Exprs(join.EqualConditions)...) tempCond = append(tempCond, join.OtherConditions...) tempCond = append(tempCond, predicates...) - tempCond = expression.ExtractFiltersFromDNFs(sctx, tempCond) - tempCond = expression.PropagateConstant(sctx, tempCond) + tempCond = expression.ExtractFiltersFromDNFs(sctx.GetExprCtx(), tempCond) + tempCond = expression.PropagateConstant(sctx.GetExprCtx(), tempCond) // Return table dual when filter is constant false or null. dual := plannercore.Conds2TableDual(join, tempCond) if dual != nil { @@ -902,9 +902,9 @@ func (*pushDownJoin) predicatePushDown( copy(remainCond, predicates) nullSensitive := join.JoinType == plannercore.AntiLeftOuterSemiJoin || join.JoinType == plannercore.LeftOuterSemiJoin if join.JoinType == plannercore.RightOuterJoin { - joinConds, remainCond = expression.PropConstOverOuterJoin(join.SCtx(), joinConds, remainCond, rightSchema, leftSchema, nullSensitive) + joinConds, remainCond = expression.PropConstOverOuterJoin(join.SCtx().GetExprCtx(), joinConds, remainCond, rightSchema, leftSchema, nullSensitive) } else { - joinConds, remainCond = expression.PropConstOverOuterJoin(join.SCtx(), joinConds, remainCond, leftSchema, rightSchema, nullSensitive) + joinConds, remainCond = expression.PropConstOverOuterJoin(join.SCtx().GetExprCtx(), joinConds, remainCond, leftSchema, rightSchema, nullSensitive) } eq, left, right, other := join.ExtractOnCondition(joinConds, leftSchema, rightSchema, false, false) join.AppendJoinConds(eq, left, right, other) @@ -914,7 +914,7 @@ func (*pushDownJoin) predicatePushDown( return leftCond, rightCond, remainCond, dual } if join.JoinType == plannercore.RightOuterJoin { - remainCond = expression.ExtractFiltersFromDNFs(join.SCtx(), remainCond) + remainCond = expression.ExtractFiltersFromDNFs(join.SCtx().GetExprCtx(), remainCond) // Only derive right where condition, because left where condition cannot be pushed down equalCond, leftPushCond, rightPushCond, otherCond = join.ExtractOnCondition(remainCond, leftSchema, rightSchema, false, true) rightCond = rightPushCond @@ -925,7 +925,7 @@ func (*pushDownJoin) predicatePushDown( remainCond = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) remainCond = append(remainCond, leftPushCond...) // nozero } else { - remainCond = expression.ExtractFiltersFromDNFs(join.SCtx(), remainCond) + remainCond = expression.ExtractFiltersFromDNFs(join.SCtx().GetExprCtx(), remainCond) // Only derive left where condition, because right where condition cannot be pushed down equalCond, leftPushCond, rightPushCond, otherCond = join.ExtractOnCondition(remainCond, leftSchema, rightSchema, true, false) leftCond = leftPushCond @@ -1313,7 +1313,7 @@ func (*PushTopNDownProjection) OnTransform(old *memo.ExprIter) (newExprs []*memo newTopN.ByItems = make([]*util.ByItems, 0, len(topN.ByItems)) for _, by := range topN.ByItems { newTopN.ByItems = append(newTopN.ByItems, &util.ByItems{ - Expr: expression.ColumnSubstitute(ctx, by.Expr, old.Children[0].Group.Prop.Schema, proj.Exprs), + Expr: expression.ColumnSubstitute(ctx.GetExprCtx(), by.Expr, old.Children[0].Group.Prop.Schema, proj.Exprs), Desc: by.Desc, }) } @@ -1459,7 +1459,7 @@ func (*MergeAdjacentTopN) Match(expr *memo.ExprIter) bool { return false } for i := 0; i < len(topN.ByItems); i++ { - if !topN.ByItems[i].Equal(topN.SCtx(), child.ByItems[i]) { + if !topN.ByItems[i].Equal(topN.SCtx().GetExprCtx(), child.ByItems[i]) { return false } } @@ -1527,7 +1527,7 @@ func (*MergeAggregationProjection) OnTransform(old *memo.ExprIter) (newExprs []* ctx := oldAgg.SCtx() groupByItems := make([]expression.Expression, len(oldAgg.GroupByItems)) for i, item := range oldAgg.GroupByItems { - groupByItems[i] = expression.ColumnSubstitute(ctx, item, projSchema, proj.Exprs) + groupByItems[i] = expression.ColumnSubstitute(ctx.GetExprCtx(), item, projSchema, proj.Exprs) } aggFuncs := make([]*aggregation.AggFuncDesc, len(oldAgg.AggFuncs)) @@ -1535,7 +1535,7 @@ func (*MergeAggregationProjection) OnTransform(old *memo.ExprIter) (newExprs []* aggFuncs[i] = aggFunc.Clone() newArgs := make([]expression.Expression, len(aggFunc.Args)) for j, arg := range aggFunc.Args { - newArgs[j] = expression.ColumnSubstitute(ctx, arg, projSchema, proj.Exprs) + newArgs[j] = expression.ColumnSubstitute(ctx.GetExprCtx(), arg, projSchema, proj.Exprs) } aggFuncs[i].Args = newArgs } @@ -1609,8 +1609,8 @@ func (r *EliminateSingleMaxMin) OnTransform(old *memo.ExprIter) (newExprs []*mem // If it can be NULL, we need to filter NULL out first. if !mysql.HasNotNullFlag(f.Args[0].GetType().GetFlag()) { sel := plannercore.LogicalSelection{}.Init(ctx, agg.QueryBlockOffset()) - isNullFunc := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), f.Args[0]) - notNullFunc := expression.NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNullFunc) + isNullFunc := expression.NewFunctionInternal(ctx.GetExprCtx(), ast.IsNull, types.NewFieldType(mysql.TypeTiny), f.Args[0]) + notNullFunc := expression.NewFunctionInternal(ctx.GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNullFunc) sel.Conditions = []expression.Expression{notNullFunc} selExpr := memo.NewGroupExpr(sel) selExpr.SetChildren(childGroup) @@ -2088,14 +2088,14 @@ func (r *TransformAggregateCaseToSelection) transform(agg *plannercore.LogicalAg caseArgsNum := len(caseArgs) // `case when a>0 then null else a end` should be converted to `case when !(a>0) then a else null end`. - var nullFlip = caseArgsNum == 3 && caseArgs[1].Equal(ctx, expression.NewNull()) && !caseArgs[2].Equal(ctx, expression.NewNull()) + var nullFlip = caseArgsNum == 3 && caseArgs[1].Equal(ctx.GetExprCtx(), expression.NewNull()) && !caseArgs[2].Equal(ctx.GetExprCtx(), expression.NewNull()) // `case when a>0 then 0 else a end` should be converted to `case when !(a>0) then a else 0 end`. - var zeroFlip = !nullFlip && caseArgsNum == 3 && caseArgs[1].Equal(ctx, expression.NewZero()) + var zeroFlip = !nullFlip && caseArgsNum == 3 && caseArgs[1].Equal(ctx.GetExprCtx(), expression.NewZero()) var outputIdx int if nullFlip || zeroFlip { outputIdx = 2 - newConditions = []expression.Expression{expression.NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), conditionFromCase)} + newConditions = []expression.Expression{expression.NewFunctionInternal(ctx.GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), conditionFromCase)} } else { outputIdx = 1 newConditions = expression.SplitCNFItems(conditionFromCase) @@ -2107,7 +2107,7 @@ func (r *TransformAggregateCaseToSelection) transform(agg *plannercore.LogicalAg // => // newAggFuncDesc: COUNT(DISTINCT y), newCondition: x = 'foo' - if aggFuncName == ast.AggFuncCount && r.isOnlyOneNotNull(ctx, caseArgs, caseArgsNum, outputIdx) { + if aggFuncName == ast.AggFuncCount && r.isOnlyOneNotNull(ctx.GetExprCtx(), caseArgs, caseArgsNum, outputIdx) { newAggFuncDesc := aggFuncDesc.Clone() newAggFuncDesc.Args = []expression.Expression{caseArgs[outputIdx]} return true, newConditions, []*aggregation.AggFuncDesc{newAggFuncDesc} @@ -2123,8 +2123,8 @@ func (r *TransformAggregateCaseToSelection) transform(agg *plannercore.LogicalAg // => newAggFuncDesc: SUM(cnt), newCondition: x = 'foo' switch { - case r.allowsSelection(aggFuncName) && (caseArgsNum == 2 || caseArgs[3-outputIdx].Equal(ctx, expression.NewNull())), // Case A1 - aggFuncName == ast.AggFuncSum && caseArgsNum == 3 && caseArgs[3-outputIdx].Equal(ctx, expression.NewZero()): // Case A2 + case r.allowsSelection(aggFuncName) && (caseArgsNum == 2 || caseArgs[3-outputIdx].Equal(ctx.GetExprCtx(), expression.NewNull())), // Case A1 + aggFuncName == ast.AggFuncSum && caseArgsNum == 3 && caseArgs[3-outputIdx].Equal(ctx.GetExprCtx(), expression.NewZero()): // Case A2 newAggFuncDesc := aggFuncDesc.Clone() newAggFuncDesc.Args = []expression.Expression{caseArgs[outputIdx]} return true, newConditions, []*aggregation.AggFuncDesc{newAggFuncDesc} @@ -2340,7 +2340,7 @@ func (*InjectProjectionBelowAgg) OnTransform(old *memo.ExprIter) (newExprs []*me for _, aggFunc := range agg.AggFuncs { copyFunc := aggFunc.Clone() // WrapCastForAggArgs will modify AggFunc, so we should clone AggFunc. - copyFunc.WrapCastForAggArgs(agg.SCtx()) + copyFunc.WrapCastForAggArgs(agg.SCtx().GetExprCtx()) copyFuncs = append(copyFuncs, copyFunc) for _, arg := range copyFunc.Args { _, isScalarFunc := arg.(*expression.ScalarFunction) @@ -2542,8 +2542,8 @@ func (*MergeAdjacentWindow) Match(expr *memo.ExprIter) bool { // Whether Partition, OrderBy and Frame parts are the same. if !(curWinPlan.EqualPartitionBy(nextWinPlan) && - curWinPlan.EqualOrderBy(ctx, nextWinPlan) && - curWinPlan.EqualFrame(ctx, nextWinPlan)) { + curWinPlan.EqualOrderBy(ctx.GetExprCtx(), nextWinPlan) && + curWinPlan.EqualFrame(ctx.GetExprCtx(), nextWinPlan)) { return false } diff --git a/pkg/planner/context/context.go b/pkg/planner/context/context.go index 5d771cab39acf..4717ce553ac41 100644 --- a/pkg/planner/context/context.go +++ b/pkg/planner/context/context.go @@ -27,11 +27,18 @@ import ( // PlanContext is the context for building plan. type PlanContext interface { - exprctx.BuildContext contextutil.ValueStoreContext tablelock.TableLockReadContext + // GetExprCtx gets the expression context. + GetExprCtx() exprctx.BuildContext + // GetStore returns the store of session. + GetStore() kv.Storage // GetSessionVars gets the session variables. GetSessionVars() *variable.SessionVars + // GetDomainInfoSchema returns the latest information schema in domain + // Different with `domain.InfoSchema()`, the information schema returned by this method + // includes the temporary table definitions stored in session + GetDomainInfoSchema() infoschema.InfoSchemaMetaVersion // GetInfoSchema returns the current infoschema GetInfoSchema() infoschema.InfoSchemaMetaVersion // UpdateColStatsUsage updates the column stats usage. diff --git a/pkg/planner/contextimpl/BUILD.bazel b/pkg/planner/contextimpl/BUILD.bazel index 08778b9480db4..c1035df8e48f9 100644 --- a/pkg/planner/contextimpl/BUILD.bazel +++ b/pkg/planner/contextimpl/BUILD.bazel @@ -6,7 +6,7 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/planner/contextimpl", visibility = ["//visibility:public"], deps = [ - "//pkg/expression/contextimpl", + "//pkg/expression/context", "//pkg/planner/context", "//pkg/sessionctx", "//pkg/sessiontxn", diff --git a/pkg/planner/contextimpl/impl.go b/pkg/planner/contextimpl/impl.go index 76c42128f2e60..04fc7015b5e7f 100644 --- a/pkg/planner/contextimpl/impl.go +++ b/pkg/planner/contextimpl/impl.go @@ -15,7 +15,7 @@ package contextimpl import ( - exprctximpl "github.com/pingcap/tidb/pkg/expression/contextimpl" + exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessiontxn" @@ -28,13 +28,13 @@ var _ context.PlanContext = struct { // PlanCtxExtendedImpl provides extended method for session context to implement `PlanContext` type PlanCtxExtendedImpl struct { - sctx sessionctx.Context - *exprctximpl.ExprCtxExtendedImpl + sctx sessionctx.Context + exprCtx exprctx.BuildContext } // NewPlanCtxExtendedImpl creates a new PlanCtxExtendedImpl. -func NewPlanCtxExtendedImpl(sctx sessionctx.Context, exprCtx *exprctximpl.ExprCtxExtendedImpl) *PlanCtxExtendedImpl { - return &PlanCtxExtendedImpl{sctx: sctx, ExprCtxExtendedImpl: exprCtx} +func NewPlanCtxExtendedImpl(sctx sessionctx.Context) *PlanCtxExtendedImpl { + return &PlanCtxExtendedImpl{sctx: sctx} } // AdviseTxnWarmup advises the txn to warm up. diff --git a/pkg/planner/core/common_plans.go b/pkg/planner/core/common_plans.go index 4f4291e9699a6..65a1ba944c42f 100644 --- a/pkg/planner/core/common_plans.go +++ b/pkg/planner/core/common_plans.go @@ -198,7 +198,7 @@ type Execute struct { func isGetVarBinaryLiteral(sctx PlanContext, expr expression.Expression) (res bool) { scalarFunc, ok := expr.(*expression.ScalarFunction) if ok && scalarFunc.FuncName.L == ast.GetVar { - name, isNull, err := scalarFunc.GetArgs()[0].EvalString(sctx, chunk.Row{}) + name, isNull, err := scalarFunc.GetArgs()[0].EvalString(sctx.GetExprCtx(), chunk.Row{}) if err != nil || isNull { res = false } else if dt, ok2 := sctx.GetSessionVars().GetUserVarVal(name); ok2 { diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index 8e01d525a39d4..f6fd9c14c66a3 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -308,11 +308,11 @@ func (p *LogicalJoin) getEnforcedMergeJoin(prop *property.PhysicalProperty, sche isExist, hasLeftColInProp, hasRightColInProp := false, false, false for joinKeyPos := 0; joinKeyPos < len(leftJoinKeys); joinKeyPos++ { var key *expression.Column - if item.Col.Equal(p.SCtx(), leftJoinKeys[joinKeyPos]) { + if item.Col.Equal(p.SCtx().GetExprCtx(), leftJoinKeys[joinKeyPos]) { key = leftJoinKeys[joinKeyPos] hasLeftColInProp = true } - if item.Col.Equal(p.SCtx(), rightJoinKeys[joinKeyPos]) { + if item.Col.Equal(p.SCtx().GetExprCtx(), rightJoinKeys[joinKeyPos]) { key = rightJoinKeys[joinKeyPos] hasRightColInProp = true } @@ -378,7 +378,7 @@ func (p *LogicalJoin) getEnforcedMergeJoin(prop *property.PhysicalProperty, sche func (p *PhysicalMergeJoin) initCompareFuncs() { p.CompareFuncs = make([]expression.CompareFunc, 0, len(p.LeftJoinKeys)) for i := range p.LeftJoinKeys { - p.CompareFuncs = append(p.CompareFuncs, expression.GetCmpFunction(p.SCtx(), p.LeftJoinKeys[i], p.RightJoinKeys[i])) + p.CompareFuncs = append(p.CompareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), p.LeftJoinKeys[i], p.RightJoinKeys[i])) } } @@ -672,8 +672,8 @@ func (p *LogicalJoin) constructIndexMergeJoin( if isOuterKeysPrefix && !prop.SortItems[i].Col.EqualColumn(join.OuterJoinKeys[keyOff2KeyOffOrderByIdx[i]]) { isOuterKeysPrefix = false } - compareFuncs = append(compareFuncs, expression.GetCmpFunction(p.SCtx(), join.OuterJoinKeys[i], join.InnerJoinKeys[i])) - outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(p.SCtx(), join.OuterJoinKeys[i], join.OuterJoinKeys[i])) + compareFuncs = append(compareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), join.OuterJoinKeys[i], join.InnerJoinKeys[i])) + outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), join.OuterJoinKeys[i], join.OuterJoinKeys[i])) } // canKeepOuterOrder means whether the prop items are the prefix of the outer join keys. canKeepOuterOrder := len(prop.SortItems) <= len(join.OuterJoinKeys) @@ -1448,13 +1448,14 @@ func (cwc *ColWithCmpFuncManager) CompareRow(lhs, rhs chunk.Row) int { // BuildRangesByRow will build range of the given row. It will eval each function's arg then call BuildRange. func (cwc *ColWithCmpFuncManager) BuildRangesByRow(ctx PlanContext, row chunk.Row) ([]*ranger.Range, error) { exprs := make([]expression.Expression, len(cwc.OpType)) + exprCtx := ctx.GetExprCtx() for i, opType := range cwc.OpType { - constantArg, err := cwc.opArg[i].Eval(ctx, row) + constantArg, err := cwc.opArg[i].Eval(exprCtx, row) if err != nil { return nil, err } cwc.TmpConstant[i].Value = constantArg - newExpr, err := expression.NewFunction(ctx, opType, types.NewFieldType(mysql.TypeTiny), cwc.TargetCol, cwc.TmpConstant[i]) + newExpr, err := expression.NewFunction(exprCtx, opType, types.NewFieldType(mysql.TypeTiny), cwc.TargetCol, cwc.TmpConstant[i]) if err != nil { return nil, err } @@ -1529,7 +1530,7 @@ func (ijHelper *indexJoinBuildHelper) resetContextForIndex(innerKeys []*expressi if ijHelper.curIdxOff2KeyOff[i] >= 0 { // Don't use the join columns if their collations are unmatched and the new collation is enabled. if collate.NewCollationEnabled() && types.IsString(idxCol.RetType.GetType()) && types.IsString(outerKeys[ijHelper.curIdxOff2KeyOff[i]].RetType.GetType()) { - et, err := expression.CheckAndDeriveCollationFromExprs(ijHelper.innerPlan.SCtx(), "equal", types.ETInt, idxCol, outerKeys[ijHelper.curIdxOff2KeyOff[i]]) + et, err := expression.CheckAndDeriveCollationFromExprs(ijHelper.innerPlan.SCtx().GetExprCtx(), "equal", types.ETInt, idxCol, outerKeys[ijHelper.curIdxOff2KeyOff[i]]) if err != nil { logutil.BgLogger().Error("Unexpected error happened during constructing index join", zap.Stack("stack")) } @@ -1659,7 +1660,7 @@ func (mr *mutableIndexJoinRange) Rebuild() error { func (ijHelper *indexJoinBuildHelper) createMutableIndexJoinRange(relatedExprs []expression.Expression, ranges []*ranger.Range, path *util.AccessPath, innerKeys, outerKeys []*expression.Column) ranger.MutableRanges { // if the plan-cache is enabled and these ranges depend on some parameters, we have to rebuild these ranges after changing parameters - if expression.MaybeOverOptimized4PlanCache(ijHelper.join.SCtx(), relatedExprs) { + if expression.MaybeOverOptimized4PlanCache(ijHelper.join.SCtx().GetExprCtx(), relatedExprs) { // assume that path, innerKeys and outerKeys will not be modified in the follow-up process return &mutableIndexJoinRange{ ranges: ranges, @@ -2349,7 +2350,7 @@ func canExprsInJoinPushdown(p *LogicalJoin, storeType kv.StoreType) bool { } equalExprs = append(equalExprs, eqCondition) } - ctx := p.SCtx() + ctx := p.SCtx().GetExprCtx() if !expression.CanExprsPushDown(ctx, equalExprs, p.SCtx().GetClient(), storeType) { return false } @@ -2625,14 +2626,15 @@ func (p *LogicalProjection) exhaustPhysicalPlans(prop *property.PhysicalProperty newProps := []*property.PhysicalProperty{newProp} // generate a mpp task candidate if mpp mode is allowed ctx := p.SCtx() + exprCtx := ctx.GetExprCtx() if newProp.TaskTp != property.MppTaskType && ctx.GetSessionVars().IsMPPAllowed() && p.canPushToCop(kv.TiFlash) && - expression.CanExprsPushDown(ctx, p.Exprs, ctx.GetClient(), kv.TiFlash) { + expression.CanExprsPushDown(exprCtx, p.Exprs, ctx.GetClient(), kv.TiFlash) { mppProp := newProp.CloneEssentialFields() mppProp.TaskTp = property.MppTaskType newProps = append(newProps, mppProp) } if newProp.TaskTp != property.CopSingleReadTaskType && ctx.GetSessionVars().AllowProjectionPushDown && p.canPushToCop(kv.TiKV) && - expression.CanExprsPushDown(ctx, p.Exprs, ctx.GetClient(), kv.TiKV) && !expression.ContainVirtualColumn(p.Exprs) { + expression.CanExprsPushDown(exprCtx, p.Exprs, ctx.GetClient(), kv.TiKV) && !expression.ContainVirtualColumn(p.Exprs) { copProp := newProp.CloneEssentialFields() copProp.TaskTp = property.CopSingleReadTaskType newProps = append(newProps, copProp) @@ -2844,8 +2846,9 @@ func (lw *LogicalWindow) tryToGetMppWindows(prop *property.PhysicalProperty) []P { allSupported := true + exprCtx := lw.SCtx().GetExprCtx() for _, windowFunc := range lw.WindowFuncDescs { - if !windowFunc.CanPushDownToTiFlash(lw.SCtx(), lw.SCtx().GetClient()) { + if !windowFunc.CanPushDownToTiFlash(exprCtx, lw.SCtx().GetClient()) { lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( "MPP mode may be blocked because window function `" + windowFunc.Name + "` or its arguments are not supported now.") allSupported = false @@ -2859,7 +2862,7 @@ func (lw *LogicalWindow) tryToGetMppWindows(prop *property.PhysicalProperty) []P } if lw.Frame != nil && lw.Frame.Type == ast.Ranges { - ctx := lw.SCtx() + ctx := lw.SCtx().GetExprCtx() if _, err := expression.ExpressionsToPBList(ctx, lw.Frame.Start.CalcFuncs, lw.SCtx().GetClient()); err != nil { lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( "MPP mode may be blocked because window function frame can't be pushed down, because " + err.Error()) @@ -3237,7 +3240,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert for i, agg := range aggFuncs { if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount { oldFT := agg.RetTp - aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false) + aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx().GetExprCtx(), ast.AggFuncSum, agg.Args, false) aggFuncs[i].TypeInfer4FinalCount(oldFT) } } @@ -3474,7 +3477,7 @@ func (p *LogicalSelection) canPushDown(storeTp kv.StoreType) bool { return !expression.ContainVirtualColumn(p.Conditions) && p.canPushToCop(storeTp) && expression.CanExprsPushDown( - p.SCtx(), + p.SCtx().GetExprCtx(), p.Conditions, p.SCtx().GetClient(), storeTp) diff --git a/pkg/planner/core/exhaust_physical_plans_test.go b/pkg/planner/core/exhaust_physical_plans_test.go index e2e0e02265f4b..006da7e8dc7bb 100644 --- a/pkg/planner/core/exhaust_physical_plans_test.go +++ b/pkg/planner/core/exhaust_physical_plans_test.go @@ -186,10 +186,10 @@ func testAnalyzeLookUpFilters(t *testing.T, testCtx *indexJoinContext, testCase ctx.GetSessionVars().RangeMaxSize = testCase.rangeMaxSize dataSourceNode := testCtx.dataSourceNode joinNode := testCtx.joinNode - pushed, err := rewriteSimpleExpr(ctx, testCase.pushedDownConds, dataSourceNode.schema, testCtx.dsNames) + pushed, err := rewriteSimpleExpr(ctx.GetExprCtx(), testCase.pushedDownConds, dataSourceNode.schema, testCtx.dsNames) require.NoError(t, err) dataSourceNode.pushedDownConds = pushed - others, err := rewriteSimpleExpr(ctx, testCase.otherConds, joinNode.schema, testCtx.joinColNames) + others, err := rewriteSimpleExpr(ctx.GetExprCtx(), testCase.otherConds, joinNode.schema, testCtx.joinColNames) require.NoError(t, err) joinNode.OtherConditions = others helper := &indexJoinBuildHelper{join: joinNode, lastColManager: nil, innerPlan: dataSourceNode} diff --git a/pkg/planner/core/explain.go b/pkg/planner/core/explain.go index 96bf83680454c..0d0d0f3e5131e 100644 --- a/pkg/planner/core/explain.go +++ b/pkg/planner/core/explain.go @@ -221,7 +221,7 @@ func (p *PhysicalTableScan) OperatorInfo(normalized bool) string { if normalized { buffer.Write(expression.SortedExplainNormalizedExpressionList(p.LateMaterializationFilterCondition)) } else { - buffer.Write(expression.SortedExplainExpressionList(p.SCtx(), p.LateMaterializationFilterCondition)) + buffer.Write(expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.LateMaterializationFilterCondition)) } } else { buffer.WriteString("empty") @@ -352,12 +352,12 @@ func (p *PhysicalIndexMergeReader) ExplainInfo() string { // ExplainInfo implements Plan interface. func (p *PhysicalUnionScan) ExplainInfo() string { - return string(expression.SortedExplainExpressionList(p.SCtx(), p.Conditions)) + return string(expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.Conditions)) } // ExplainInfo implements Plan interface. func (p *PhysicalSelection) ExplainInfo() string { - exprStr := string(expression.SortedExplainExpressionList(p.SCtx(), p.Conditions)) + exprStr := string(expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.Conditions)) if p.TiFlashFineGrainedShuffleStreamCount > 0 { exprStr += fmt.Sprintf(", stream_count: %d", p.TiFlashFineGrainedShuffleStreamCount) } @@ -424,7 +424,7 @@ func (p *PhysicalTableDual) ExplainInfo() string { // ExplainInfo implements Plan interface. func (p *PhysicalSort) ExplainInfo() string { buffer := bytes.NewBufferString("") - buffer = explainByItems(p.SCtx(), buffer, p.ByItems) + buffer = explainByItems(p.SCtx().GetExprCtx(), buffer, p.ByItems) if p.TiFlashFineGrainedShuffleStreamCount > 0 { fmt.Fprintf(buffer, ", stream_count: %d", p.TiFlashFineGrainedShuffleStreamCount) } @@ -474,7 +474,7 @@ func (p *basePhysicalAgg) explainInfo(normalized bool) string { builder := &strings.Builder{} if len(p.GroupByItems) > 0 { builder.WriteString("group by:") - builder.Write(sortedExplainExpressionList(p.SCtx(), p.GroupByItems)) + builder.Write(sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.GroupByItems)) builder.WriteString(", ") } for i := 0; i < len(p.AggFuncs); i++ { @@ -483,9 +483,9 @@ func (p *basePhysicalAgg) explainInfo(normalized bool) string { if normalized { colName = p.schema.Columns[i].ExplainNormalizedInfo() } else { - colName = p.schema.Columns[i].ExplainInfo(p.SCtx()) + colName = p.schema.Columns[i].ExplainInfo(p.SCtx().GetExprCtx()) } - builder.WriteString(aggregation.ExplainAggFunc(p.SCtx(), p.AggFuncs[i], normalized)) + builder.WriteString(aggregation.ExplainAggFunc(p.SCtx().GetExprCtx(), p.AggFuncs[i], normalized)) builder.WriteString("->") builder.WriteString(colName) if i+1 < len(p.AggFuncs) { @@ -530,36 +530,36 @@ func (p *PhysicalIndexJoin) explainInfo(normalized bool, isIndexMergeJoin bool) } if len(p.OuterJoinKeys) > 0 { buffer.WriteString(", outer key:") - buffer.Write(expression.ExplainColumnList(p.SCtx(), p.OuterJoinKeys)) + buffer.Write(expression.ExplainColumnList(p.SCtx().GetExprCtx(), p.OuterJoinKeys)) } if len(p.InnerJoinKeys) > 0 { buffer.WriteString(", inner key:") - buffer.Write(expression.ExplainColumnList(p.SCtx(), p.InnerJoinKeys)) + buffer.Write(expression.ExplainColumnList(p.SCtx().GetExprCtx(), p.InnerJoinKeys)) } if len(p.OuterHashKeys) > 0 && !isIndexMergeJoin { exprs := make([]expression.Expression, 0, len(p.OuterHashKeys)) for i := range p.OuterHashKeys { - expr, err := expression.NewFunctionBase(p.SCtx(), ast.EQ, types.NewFieldType(mysql.TypeLonglong), p.OuterHashKeys[i], p.InnerHashKeys[i]) + expr, err := expression.NewFunctionBase(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeLonglong), p.OuterHashKeys[i], p.InnerHashKeys[i]) if err != nil { logutil.BgLogger().Warn("fail to NewFunctionBase", zap.Error(err)) } exprs = append(exprs, expr) } buffer.WriteString(", equal cond:") - buffer.Write(sortedExplainExpressionList(p.SCtx(), exprs)) + buffer.Write(sortedExplainExpressionList(p.SCtx().GetExprCtx(), exprs)) } if len(p.LeftConditions) > 0 { buffer.WriteString(", left cond:") - buffer.Write(sortedExplainExpressionList(p.SCtx(), p.LeftConditions)) + buffer.Write(sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.LeftConditions)) } if len(p.RightConditions) > 0 { buffer.WriteString(", right cond:") - buffer.Write(sortedExplainExpressionList(p.SCtx(), p.RightConditions)) + buffer.Write(sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.RightConditions)) } if len(p.OtherConditions) > 0 { buffer.WriteString(", other cond:") - buffer.Write(sortedExplainExpressionList(p.SCtx(), p.OtherConditions)) + buffer.Write(sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.OtherConditions)) } return buffer.String() } @@ -651,11 +651,11 @@ func (p *PhysicalHashJoin) explainInfo(normalized bool) string { } if len(p.RightConditions) > 0 { buffer.WriteString(", right cond:") - buffer.Write(sortedExplainExpressionList(p.SCtx(), p.RightConditions)) + buffer.Write(sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.RightConditions)) } if len(p.OtherConditions) > 0 { buffer.WriteString(", other cond:") - buffer.Write(sortedExplainExpressionList(p.SCtx(), p.OtherConditions)) + buffer.Write(sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.OtherConditions)) } if p.TiFlashFineGrainedShuffleStreamCount > 0 { fmt.Fprintf(buffer, ", stream_count: %d", p.TiFlashFineGrainedShuffleStreamCount) @@ -690,11 +690,11 @@ func (p *PhysicalMergeJoin) explainInfo(normalized bool) string { buffer := bytes.NewBufferString(p.JoinType.String()) if len(p.LeftJoinKeys) > 0 { fmt.Fprintf(buffer, ", left key:%s", - expression.ExplainColumnList(p.SCtx(), p.LeftJoinKeys)) + expression.ExplainColumnList(p.SCtx().GetExprCtx(), p.LeftJoinKeys)) } if len(p.RightJoinKeys) > 0 { fmt.Fprintf(buffer, ", right key:%s", - expression.ExplainColumnList(p.SCtx(), p.RightJoinKeys)) + expression.ExplainColumnList(p.SCtx().GetExprCtx(), p.RightJoinKeys)) } if len(p.LeftConditions) > 0 { if normalized { @@ -705,11 +705,11 @@ func (p *PhysicalMergeJoin) explainInfo(normalized bool) string { } if len(p.RightConditions) > 0 { fmt.Fprintf(buffer, ", right cond:%s", - sortedExplainExpressionList(p.SCtx(), p.RightConditions)) + sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.RightConditions)) } if len(p.OtherConditions) > 0 { fmt.Fprintf(buffer, ", other cond:%s", - sortedExplainExpressionList(p.SCtx(), p.OtherConditions)) + sortedExplainExpressionList(p.SCtx().GetExprCtx(), p.OtherConditions)) } return buffer.String() } @@ -746,7 +746,7 @@ func (p *PhysicalTopN) ExplainInfo() string { if len(p.GetPartitionBy()) > 0 { buffer.WriteString("order by ") } - buffer = explainByItems(p.SCtx(), buffer, p.ByItems) + buffer = explainByItems(p.SCtx().GetExprCtx(), buffer, p.ByItems) } fmt.Fprintf(buffer, ", offset:%v, count:%v", p.Offset, p.Count) return buffer.String() @@ -778,15 +778,15 @@ func (p *PhysicalWindow) formatFrameBound(buffer *bytes.Buffer, bound *FrameBoun if bound.UnBounded { buffer.WriteString("unbounded") } else if len(bound.CalcFuncs) > 0 { - ctx := p.SCtx() + exprCtx := p.SCtx().GetExprCtx() sf := bound.CalcFuncs[0].(*expression.ScalarFunction) switch sf.FuncName.L { case ast.DateAdd, ast.DateSub: // For `interval '2:30' minute_second`. - fmt.Fprintf(buffer, "interval %s %s", sf.GetArgs()[1].ExplainInfo(ctx), sf.GetArgs()[2].ExplainInfo(ctx)) + fmt.Fprintf(buffer, "interval %s %s", sf.GetArgs()[1].ExplainInfo(exprCtx), sf.GetArgs()[2].ExplainInfo(exprCtx)) case ast.Plus, ast.Minus: // For `1 preceding` of range frame. - fmt.Fprintf(buffer, "%s", sf.GetArgs()[1].ExplainInfo(ctx)) + fmt.Fprintf(buffer, "%s", sf.GetArgs()[1].ExplainInfo(exprCtx)) } } else { fmt.Fprintf(buffer, "%d", bound.Num) @@ -813,12 +813,12 @@ func (p *PhysicalWindow) ExplainInfo() string { buffer.WriteString(" ") } buffer.WriteString("order by ") - ctx := p.SCtx() + exprCtx := p.SCtx().GetExprCtx() for i, item := range p.OrderBy { if item.Desc { - fmt.Fprintf(buffer, "%s desc", item.Col.ExplainInfo(ctx)) + fmt.Fprintf(buffer, "%s desc", item.Col.ExplainInfo(exprCtx)) } else { - fmt.Fprintf(buffer, "%s", item.Col.ExplainInfo(ctx)) + fmt.Fprintf(buffer, "%s", item.Col.ExplainInfo(exprCtx)) } if i+1 < len(p.OrderBy) { @@ -879,15 +879,15 @@ func (p *LogicalJoin) ExplainInfo() string { } if len(p.LeftConditions) > 0 { fmt.Fprintf(buffer, ", left cond:%s", - expression.SortedExplainExpressionList(p.SCtx(), p.LeftConditions)) + expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.LeftConditions)) } if len(p.RightConditions) > 0 { fmt.Fprintf(buffer, ", right cond:%s", - expression.SortedExplainExpressionList(p.SCtx(), p.RightConditions)) + expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.RightConditions)) } if len(p.OtherConditions) > 0 { fmt.Fprintf(buffer, ", other cond:%s", - expression.SortedExplainExpressionList(p.SCtx(), p.OtherConditions)) + expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.OtherConditions)) } return buffer.String() } @@ -897,12 +897,12 @@ func (p *LogicalAggregation) ExplainInfo() string { buffer := bytes.NewBufferString("") if len(p.GroupByItems) > 0 { fmt.Fprintf(buffer, "group by:%s, ", - expression.SortedExplainExpressionList(p.SCtx(), p.GroupByItems)) + expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.GroupByItems)) } if len(p.AggFuncs) > 0 { buffer.WriteString("funcs:") for i, agg := range p.AggFuncs { - buffer.WriteString(aggregation.ExplainAggFunc(p.SCtx(), agg, false)) + buffer.WriteString(aggregation.ExplainAggFunc(p.SCtx().GetExprCtx(), agg, false)) if i+1 < len(p.AggFuncs) { buffer.WriteString(", ") } @@ -918,7 +918,7 @@ func (p *LogicalProjection) ExplainInfo() string { // ExplainInfo implements Plan interface. func (p *LogicalSelection) ExplainInfo() string { - return string(expression.SortedExplainExpressionList(p.SCtx(), p.Conditions)) + return string(expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.Conditions)) } // ExplainInfo implements Plan interface. @@ -966,7 +966,7 @@ func (p *PhysicalExchangeSender) ExplainInfo() string { fmt.Fprintf(buffer, ", Compression: %s", p.CompressionMode.Name()) } if p.ExchangeType == tipb.ExchangeType_Hash { - fmt.Fprintf(buffer, ", Hash Cols: %s", property.ExplainColumnList(p.SCtx(), p.HashCols)) + fmt.Fprintf(buffer, ", Hash Cols: %s", property.ExplainColumnList(p.SCtx().GetExprCtx(), p.HashCols)) } if len(p.Tasks) > 0 { fmt.Fprintf(buffer, ", tasks: [") @@ -996,7 +996,7 @@ func (p *PhysicalExchangeReceiver) ExplainInfo() (res string) { func (p *LogicalUnionScan) ExplainInfo() string { buffer := bytes.NewBufferString("") fmt.Fprintf(buffer, "conds:%s", - expression.SortedExplainExpressionList(p.SCtx(), p.conditions)) + expression.SortedExplainExpressionList(p.SCtx().GetExprCtx(), p.conditions)) fmt.Fprintf(buffer, ", handle:%s", p.handleCols) return buffer.String() } @@ -1034,7 +1034,7 @@ func explainNormalizedByItems(buffer *bytes.Buffer, byItems []*util.ByItems) *by // ExplainInfo implements Plan interface. func (p *LogicalSort) ExplainInfo() string { buffer := bytes.NewBufferString("") - return explainByItems(p.SCtx(), buffer, p.ByItems).String() + return explainByItems(p.SCtx().GetExprCtx(), buffer, p.ByItems).String() } // ExplainInfo implements Plan interface. @@ -1044,7 +1044,7 @@ func (lt *LogicalTopN) ExplainInfo() string { if len(lt.GetPartitionBy()) > 0 && len(lt.ByItems) > 0 { buffer.WriteString("order by ") } - buffer = explainByItems(lt.SCtx(), buffer, lt.ByItems) + buffer = explainByItems(lt.SCtx().GetExprCtx(), buffer, lt.ByItems) fmt.Fprintf(buffer, ", offset:%v, count:%v", lt.Offset, lt.Count) return buffer.String() } diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index dffdf08d13784..4ccf0168f817a 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -61,7 +61,7 @@ func evalAstExprWithPlanCtx(sctx PlanContext, expr ast.ExprNode) (types.Datum, e if err != nil { return types.Datum{}, err } - return newExpr.Eval(sctx, chunk.Row{}) + return newExpr.Eval(sctx.GetExprCtx(), chunk.Row{}) } // evalAstExpr evaluates ast expression directly. @@ -249,7 +249,7 @@ func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p LogicalPlan) if len(b.rewriterPool) < b.rewriterCounter { rewriter = &expressionRewriter{ - sctx: b.ctx, ctx: ctx, + sctx: b.ctx.GetExprCtx(), ctx: ctx, planCtx: &exprRewriterPlanCtx{plan: p, builder: b, rollExpand: b.currentBlockExpand}, } rewriter.sctx.SetValue(expression.TiDBDecodeKeyFunctionKey, decodeKeyFromString) @@ -797,7 +797,7 @@ func (er *expressionRewriter) handleOtherComparableSubq(planCtx *exprRewriterPla if useMin { funcName = ast.AggFuncMin } - funcMaxOrMin, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx, funcName, []expression.Expression{rexpr}, false) + funcMaxOrMin, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx.GetExprCtx(), funcName, []expression.Expression{rexpr}, false) if err != nil { er.err = err return @@ -825,7 +825,7 @@ func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, innerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) outerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) + funcSum, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx.GetExprCtx(), ast.AggFuncSum, []expression.Expression{innerIsNull}, false) if err != nil { er.err = err return @@ -839,7 +839,7 @@ func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, innerHasNull := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.NewZero()) // Build `count(1)` aggregation to check if subquery is empty. - funcCount, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx, ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false) + funcCount, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx.GetExprCtx(), ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false) if err != nil { er.err = err return @@ -906,14 +906,15 @@ func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, func (er *expressionRewriter) handleNEAny(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np LogicalPlan, markNoDecorrelate bool) { intest.AssertNotNil(planCtx) sctx := planCtx.builder.ctx + exprCtx := sctx.GetExprCtx() // If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id != s.id. // So use function max to filter NULL. - maxFunc, err := aggregation.NewAggFuncDesc(sctx, ast.AggFuncMax, []expression.Expression{rexpr}, false) + maxFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncMax, []expression.Expression{rexpr}, false) if err != nil { er.err = err return } - countFunc, err := aggregation.NewAggFuncDesc(sctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + countFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{rexpr}, true) if err != nil { er.err = err return @@ -948,12 +949,13 @@ func (er *expressionRewriter) handleNEAny(planCtx *exprRewriterPlanCtx, lexpr, r func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np LogicalPlan, markNoDecorrelate bool) { intest.AssertNotNil(planCtx) sctx := planCtx.builder.ctx - firstRowFunc, err := aggregation.NewAggFuncDesc(sctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + exprCtx := sctx.GetExprCtx() + firstRowFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) if err != nil { er.err = err return } - countFunc, err := aggregation.NewAggFuncDesc(sctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + countFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{rexpr}, true) if err != nil { er.err = err return @@ -1058,7 +1060,7 @@ func (er *expressionRewriter) handleExistSubquery(ctx context.Context, planCtx * scalarSubQ.SetCoercibility(np.Schema().Columns[0].Coercibility()) b.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx) if v.Not { - notWrapped, err := expression.NewFunction(b.ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), scalarSubQ) + notWrapped, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), scalarSubQ) if err != nil { er.err = err return v, true diff --git a/pkg/planner/core/find_best_task.go b/pkg/planner/core/find_best_task.go index 16b6aafc50476..ed1d2610038a0 100644 --- a/pkg/planner/core/find_best_task.go +++ b/pkg/planner/core/find_best_task.go @@ -700,7 +700,7 @@ func (p *LogicalMemTable) findBestTask(prop *property.PhysicalProperty, planCoun func (ds *DataSource) tryToGetDualTask() (task, error) { for _, cond := range ds.pushedDownConds { if con, ok := cond.(*expression.Constant); ok && con.DeferredExpr == nil && con.ParamMarker == nil { - result, _, err := expression.EvalBool(ds.SCtx(), []expression.Expression{cond}, chunk.Row{}) + result, _, err := expression.EvalBool(ds.SCtx().GetExprCtx(), []expression.Expression{cond}, chunk.Row{}) if err != nil { return nil, err } @@ -859,15 +859,15 @@ func (ds *DataSource) isMatchPropForIndexMerge(path *util.AccessPath, prop *prop func (ds *DataSource) getTableCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { candidate := &candidatePath{path: path} candidate.isMatchProp = ds.isMatchProp(path, prop) - candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx(), path.AccessConds, nil, nil) + candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx(), path.AccessConds, nil, nil) return candidate } func (ds *DataSource) getIndexCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { candidate := &candidatePath{path: path} candidate.isMatchProp = ds.isMatchProp(path, prop) - candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx(), path.AccessConds, path.IdxCols, path.IdxColLens) - candidate.indexCondsColMap = util.ExtractCol2Len(ds.SCtx(), append(path.AccessConds, path.IndexFilters...), path.FullIdxCols, path.FullIdxColLens) + candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx(), path.AccessConds, path.IdxCols, path.IdxColLens) + candidate.indexCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx(), append(path.AccessConds, path.IndexFilters...), path.FullIdxCols, path.FullIdxColLens) return candidate } @@ -1161,7 +1161,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter // if we already know the range of the scan is empty, just return a TableDual if len(path.Ranges) == 0 { // We should uncache the tableDual plan. - if expression.MaybeOverOptimized4PlanCache(ds.SCtx(), path.AccessConds) { + if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), path.AccessConds) { ds.SCtx().GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("get a TableDual plan")) } dual := PhysicalTableDual{}.Init(ds.SCtx(), ds.StatsInfo(), ds.QueryBlockOffset()) @@ -1240,7 +1240,7 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter // Batch/PointGet plans may be over-optimized, like `a>=1(?) and a<=1(?)` --> `a=1` --> PointGet(a=1). // For safety, prevent these plans from the plan cache here. - if !pointGetTask.invalid() && expression.MaybeOverOptimized4PlanCache(ds.SCtx(), candidate.path.AccessConds) && !isSafePointGetPath4PlanCache(ds.SCtx(), candidate.path) { + if !pointGetTask.invalid() && expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), candidate.path.AccessConds) && !isSafePointGetPath4PlanCache(ds.SCtx(), candidate.path) { ds.SCtx().GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("Batch/PointGet plans may be over-optimized")) } @@ -1352,7 +1352,7 @@ func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, c tblColHists: ds.TblColHists, } cop.physPlanPartInfo = PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx(), ds.allConds), + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.allConds), PartitionNames: ds.partitionNames, Columns: ds.TblCols, ColumnNames: ds.names, @@ -1554,7 +1554,7 @@ func (ds *DataSource) buildIndexMergeTableScan(tableFilters []expression.Express } var currentTopPlan PhysicalPlan = ts if len(tableFilters) > 0 { - pushedFilters, remainingFilters := extractFiltersForIndexMerge(ds.SCtx(), ds.SCtx().GetClient(), tableFilters) + pushedFilters, remainingFilters := extractFiltersForIndexMerge(ds.SCtx().GetExprCtx(), ds.SCtx().GetClient(), tableFilters) pushedFilters1, remainingFilters1 := SplitSelCondsWithVirtualColumn(pushedFilters) pushedFilters = pushedFilters1 remainingFilters = append(remainingFilters, remainingFilters1...) @@ -1652,8 +1652,8 @@ func (ds *DataSource) indexCoveringColumn(column *expression.Column, indexColumn if column.ID == model.ExtraHandleID { return true } - coveredByPlainIndex := isIndexColsCoveringCol(ds.SCtx(), column, indexColumns, idxColLens, ignoreLen) - coveredByClusteredIndex := isIndexColsCoveringCol(ds.SCtx(), column, ds.commonHandleCols, ds.commonHandleLens, ignoreLen) + coveredByPlainIndex := isIndexColsCoveringCol(ds.SCtx().GetExprCtx(), column, indexColumns, idxColLens, ignoreLen) + coveredByClusteredIndex := isIndexColsCoveringCol(ds.SCtx().GetExprCtx(), column, ds.commonHandleCols, ds.commonHandleLens, ignoreLen) if !coveredByPlainIndex && !coveredByClusteredIndex { return false } @@ -1763,7 +1763,7 @@ func (ds *DataSource) convertToIndexScan(prop *property.PhysicalProperty, expectCnt: uint64(prop.ExpectedCnt), } cop.physPlanPartInfo = PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx(), ds.allConds), + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.allConds), PartitionNames: ds.partitionNames, Columns: ds.TblCols, ColumnNames: ds.names, @@ -1948,10 +1948,10 @@ func (is *PhysicalIndexScan) addPushedDownSelection(copTask *copTask, p *DataSou tableConds, copTask.rootTaskConds = SplitSelCondsWithVirtualColumn(tableConds) var newRootConds []expression.Expression - indexConds, newRootConds = expression.PushDownExprs(is.SCtx(), indexConds, is.SCtx().GetClient(), kv.TiKV) + indexConds, newRootConds = expression.PushDownExprs(is.SCtx().GetExprCtx(), indexConds, is.SCtx().GetClient(), kv.TiKV) copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) - tableConds, newRootConds = expression.PushDownExprs(is.SCtx(), tableConds, is.SCtx().GetClient(), kv.TiKV) + tableConds, newRootConds = expression.PushDownExprs(is.SCtx().GetExprCtx(), tableConds, is.SCtx().GetClient(), kv.TiKV) copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) if indexConds != nil { @@ -2015,7 +2015,7 @@ func matchIndicesProp(sctx PlanContext, idxCols []*expression.Column, colLens [] return false } for i, item := range propItems { - if colLens[i] != types.UnspecifiedLength || !item.Col.EqualByExprAndID(sctx, idxCols[i]) { + if colLens[i] != types.UnspecifiedLength || !item.Col.EqualByExprAndID(sctx.GetExprCtx(), idxCols[i]) { return false } } @@ -2186,7 +2186,7 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid tblColHists: ds.TblColHists, } ts.PlanPartInfo = PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx(), ds.allConds), + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.allConds), PartitionNames: ds.partitionNames, Columns: ds.TblCols, ColumnNames: ds.names, @@ -2220,7 +2220,7 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid tblColHists: ds.TblColHists, } copTask.physPlanPartInfo = PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx(), ds.allConds), + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.allConds), PartitionNames: ds.partitionNames, Columns: ds.TblCols, ColumnNames: ds.names, @@ -2451,7 +2451,7 @@ func (ds *DataSource) convertToBatchPointGet(prop *property.PhysicalProperty, ca func (ts *PhysicalTableScan) addPushedDownSelectionToMppTask(mpp *mppTask, stats *property.StatsInfo) *mppTask { filterCondition, rootTaskConds := SplitSelCondsWithVirtualColumn(ts.filterCondition) var newRootConds []expression.Expression - filterCondition, newRootConds = expression.PushDownExprs(ts.SCtx(), filterCondition, ts.SCtx().GetClient(), ts.StoreType) + filterCondition, newRootConds = expression.PushDownExprs(ts.SCtx().GetExprCtx(), filterCondition, ts.SCtx().GetClient(), ts.StoreType) mpp.rootTaskConds = append(rootTaskConds, newRootConds...) ts.filterCondition = filterCondition @@ -2467,7 +2467,7 @@ func (ts *PhysicalTableScan) addPushedDownSelectionToMppTask(mpp *mppTask, stats func (ts *PhysicalTableScan) addPushedDownSelection(copTask *copTask, stats *property.StatsInfo) { ts.filterCondition, copTask.rootTaskConds = SplitSelCondsWithVirtualColumn(ts.filterCondition) var newRootConds []expression.Expression - ts.filterCondition, newRootConds = expression.PushDownExprs(ts.SCtx(), ts.filterCondition, ts.SCtx().GetClient(), ts.StoreType) + ts.filterCondition, newRootConds = expression.PushDownExprs(ts.SCtx().GetExprCtx(), ts.filterCondition, ts.SCtx().GetClient(), ts.StoreType) copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) // Add filter condition to table plan now. diff --git a/pkg/planner/core/indexmerge_path.go b/pkg/planner/core/indexmerge_path.go index b3b2184a72fe3..d0dbc594fc764 100644 --- a/pkg/planner/core/indexmerge_path.go +++ b/pkg/planner/core/indexmerge_path.go @@ -68,7 +68,7 @@ func (ds *DataSource) generateIndexMergePath() error { // We will create new Selection for exprs that cannot be pushed in convertToIndexMergeScan. indexMergeConds := make([]expression.Expression, 0, len(ds.allConds)) for _, expr := range ds.allConds { - indexMergeConds = append(indexMergeConds, expression.PushDownNot(ds.SCtx(), expr)) + indexMergeConds = append(indexMergeConds, expression.PushDownNot(ds.SCtx().GetExprCtx(), expr)) } sessionAndStmtPermission := (ds.SCtx().GetSessionVars().GetEnableIndexMerge() || len(ds.indexMergeHints) > 0) && !stmtCtx.NoIndexMergeHint @@ -142,7 +142,7 @@ func (ds *DataSource) generateNormalIndexPartialPaths4DNF(dnfItems []expression. cnfItems := expression.SplitCNFItems(item) pushedDownCNFItems := make([]expression.Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { - if expression.CanExprsPushDown(ds.SCtx(), + if expression.CanExprsPushDown(ds.SCtx().GetExprCtx(), []expression.Expression{cnfItem}, ds.SCtx().GetClient(), kv.TiKV, @@ -181,13 +181,13 @@ func (ds *DataSource) generateNormalIndexPartialPaths4DNF(dnfItems []expression. partialPath.TableFilters = nil } // If any partial path's index filter cannot be pushed to TiKV, we should keep the whole DNF filter. - if len(partialPath.IndexFilters) != 0 && !expression.CanExprsPushDown(ds.SCtx(), partialPath.IndexFilters, ds.SCtx().GetClient(), kv.TiKV) { + if len(partialPath.IndexFilters) != 0 && !expression.CanExprsPushDown(ds.SCtx().GetExprCtx(), partialPath.IndexFilters, ds.SCtx().GetClient(), kv.TiKV) { needSelection = true // Clear IndexFilter, the whole filter will be put in indexMergePath.TableFilters. partialPath.IndexFilters = nil } // Keep this filter as a part of table filters for safety if it has any parameter. - if expression.MaybeOverOptimized4PlanCache(ds.SCtx(), cnfItems) { + if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), cnfItems) { needSelection = true } usedMap[offset] = true @@ -214,7 +214,7 @@ func (ds *DataSource) generateIndexMergeOrPaths(filters []expression.Expression) pushedDownCNFItems := make([]expression.Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { - if expression.CanExprsPushDown(ds.SCtx(), + if expression.CanExprsPushDown(ds.SCtx().GetExprCtx(), []expression.Expression{cnfItem}, ds.SCtx().GetClient(), kv.TiKV, @@ -259,10 +259,10 @@ func (ds *DataSource) generateIndexMergeOrPaths(filters []expression.Expression) indexCondsForP := p.AccessConds[:] indexCondsForP = append(indexCondsForP, p.IndexFilters...) if len(indexCondsForP) > 0 { - accessConds = append(accessConds, expression.ComposeCNFCondition(ds.SCtx(), indexCondsForP...)) + accessConds = append(accessConds, expression.ComposeCNFCondition(ds.SCtx().GetExprCtx(), indexCondsForP...)) } } - accessDNF := expression.ComposeDNFCondition(ds.SCtx(), accessConds...) + accessDNF := expression.ComposeDNFCondition(ds.SCtx().GetExprCtx(), accessConds...) sel, _, err := cardinality.Selectivity(ds.SCtx(), ds.tableStats.HistColl, []expression.Expression{accessDNF}, nil) if err != nil { logutil.BgLogger().Debug("something wrong happened, use the default selectivity", zap.Error(err)) @@ -442,19 +442,19 @@ func (ds *DataSource) buildIndexMergeOrPath( shouldKeepCurrentFilter = true } // If any partial path's index filter cannot be pushed to TiKV, we should keep the whole DNF filter. - if len(path.IndexFilters) != 0 && !expression.CanExprsPushDown(ds.SCtx(), path.IndexFilters, ds.SCtx().GetClient(), kv.TiKV) { + if len(path.IndexFilters) != 0 && !expression.CanExprsPushDown(ds.SCtx().GetExprCtx(), path.IndexFilters, ds.SCtx().GetClient(), kv.TiKV) { shouldKeepCurrentFilter = true // Clear IndexFilter, the whole filter will be put in indexMergePath.TableFilters. path.IndexFilters = nil } - if len(path.TableFilters) != 0 && !expression.CanExprsPushDown(ds.SCtx(), path.TableFilters, ds.SCtx().GetClient(), kv.TiKV) { + if len(path.TableFilters) != 0 && !expression.CanExprsPushDown(ds.SCtx().GetExprCtx(), path.TableFilters, ds.SCtx().GetClient(), kv.TiKV) { shouldKeepCurrentFilter = true path.TableFilters = nil } } // Keep this filter as a part of table filters for safety if it has any parameter. - if expression.MaybeOverOptimized4PlanCache(ds.SCtx(), filters[current:current+1]) { + if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), filters[current:current+1]) { shouldKeepCurrentFilter = true } if shouldKeepCurrentFilter { @@ -583,7 +583,7 @@ func (ds *DataSource) generateIndexMergeAndPaths(normalPathCnt int, usedAccessMa coveredConds = append(coveredConds, path.AccessConds...) for i, cond := range path.IndexFilters { // IndexFilters can be covered by partial path if it can be pushed down to TiKV. - if !expression.CanExprsPushDown(ds.SCtx(), []expression.Expression{cond}, ds.SCtx().GetClient(), kv.TiKV) { + if !expression.CanExprsPushDown(ds.SCtx().GetExprCtx(), []expression.Expression{cond}, ds.SCtx().GetClient(), kv.TiKV) { path.IndexFilters = append(path.IndexFilters[:i], path.IndexFilters[i+1:]...) notCoveredConds = append(notCoveredConds, cond) } else { @@ -624,7 +624,7 @@ func (ds *DataSource) generateIndexMergeAndPaths(normalPathCnt int, usedAccessMa } // Keep these partial filters as a part of table filters for safety if there is any parameter. - if expression.MaybeOverOptimized4PlanCache(ds.SCtx(), partialFilters) { + if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), partialFilters) { dedupedFinalFilters = append(dedupedFinalFilters, partialFilters...) } @@ -872,7 +872,7 @@ func (ds *DataSource) generateIndexMerge4NormalIndex(regularPathCount int, index // PushDownExprs() will append extra warnings, which is annoying. So we reset warnings here. warnings := stmtCtx.GetWarnings() extraWarnings := stmtCtx.GetExtraWarnings() - _, remaining := expression.PushDownExprs(ds.SCtx(), indexMergeConds, ds.SCtx().GetClient(), kv.UnSpecified) + _, remaining := expression.PushDownExprs(ds.SCtx().GetExprCtx(), indexMergeConds, ds.SCtx().GetClient(), kv.UnSpecified) stmtCtx.SetWarnings(warnings) stmtCtx.SetExtraWarnings(extraWarnings) if len(remaining) > 0 { @@ -1291,7 +1291,7 @@ func buildPartialPaths4MVIndex( virColVals = append(virColVals, v) case ast.JSONContains: // (json_contains(a->'$.zip', '[1, 2, 3]') isIntersection = true - virColVals, ok = jsonArrayExpr2Exprs(sctx, sf.GetArgs()[1], jsonType) + virColVals, ok = jsonArrayExpr2Exprs(sctx.GetExprCtx(), sf.GetArgs()[1], jsonType) if !ok || len(virColVals) == 0 { // json_contains(JSON, '[]') is TRUE. If the row has an empty array, it'll not exist on multi-valued index, // but the `json_contains(array, '[]')` is still true, so also don't try to scan on the index. @@ -1299,15 +1299,15 @@ func buildPartialPaths4MVIndex( } case ast.JSONOverlaps: // (json_overlaps(a->'$.zip', '[1, 2, 3]') var jsonPathIdx int - if sf.GetArgs()[0].Equal(sctx, targetJSONPath) { + if sf.GetArgs()[0].Equal(sctx.GetExprCtx(), targetJSONPath) { jsonPathIdx = 0 // (json_overlaps(a->'$.zip', '[1, 2, 3]') - } else if sf.GetArgs()[1].Equal(sctx, targetJSONPath) { + } else if sf.GetArgs()[1].Equal(sctx.GetExprCtx(), targetJSONPath) { jsonPathIdx = 1 // (json_overlaps('[1, 2, 3]', a->'$.zip') } else { return nil, false, false, nil } var ok bool - virColVals, ok = jsonArrayExpr2Exprs(sctx, sf.GetArgs()[1-jsonPathIdx], jsonType) + virColVals, ok = jsonArrayExpr2Exprs(sctx.GetExprCtx(), sf.GetArgs()[1-jsonPathIdx], jsonType) if !ok || len(virColVals) == 0 { // forbid empty array for safety return nil, false, false, nil } @@ -1323,7 +1323,7 @@ func buildPartialPaths4MVIndex( for _, v := range virColVals { // rewrite json functions to EQ to calculate range, `(1 member of j)` -> `j=1`. - eq, err := expression.NewFunction(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), virCol, v) + eq, err := expression.NewFunction(sctx.GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), virCol, v) if err != nil { return nil, false, false, err } @@ -1587,12 +1587,12 @@ func checkFilter4MVIndexColumn(sctx PlanContext, filter expression.Expression, i } switch sf.FuncName.L { case ast.JSONMemberOf: // (1 member of a) - return targetJSONPath.Equal(sctx, sf.GetArgs()[1]) + return targetJSONPath.Equal(sctx.GetExprCtx(), sf.GetArgs()[1]) case ast.JSONContains: // json_contains(a, '1') - return targetJSONPath.Equal(sctx, sf.GetArgs()[0]) + return targetJSONPath.Equal(sctx.GetExprCtx(), sf.GetArgs()[0]) case ast.JSONOverlaps: // json_overlaps(a, '1') or json_overlaps('1', a) - return targetJSONPath.Equal(sctx, sf.GetArgs()[0]) || - targetJSONPath.Equal(sctx, sf.GetArgs()[1]) + return targetJSONPath.Equal(sctx.GetExprCtx(), sf.GetArgs()[0]) || + targetJSONPath.Equal(sctx.GetExprCtx(), sf.GetArgs()[1]) default: return false } @@ -1615,7 +1615,7 @@ func checkFilter4MVIndexColumn(sctx PlanContext, filter expression.Expression, i if argCol == nil || argConst == nil { return false } - if argCol.Equal(sctx, idxCol) { + if argCol.Equal(sctx.GetExprCtx(), idxCol) { return true } } diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index 42ef9ac26f97a..7531022f3cff9 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -102,7 +102,7 @@ func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) { func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) { if v, ok := inNode.(*ast.PositionExpr); ok { - pos, isNull, err := expression.PosFromPositionExpr(a.ctx, v) + pos, isNull, err := expression.PosFromPositionExpr(a.ctx.GetExprCtx(), v) if err != nil { a.err = err } @@ -281,7 +281,7 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu p = np newArgList = append(newArgList, newArg) } - newFunc, err := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), aggFunc.F, newArgList, aggFunc.Distinct) if err != nil { return nil, nil, err } @@ -313,7 +313,7 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu combined := false for j := 0; j < i; j++ { oldFunc := plan4Agg.AggFuncs[aggIndexMap[j]] - if oldFunc.Equal(b.ctx, newFunc) { + if oldFunc.Equal(b.ctx.GetExprCtx(), newFunc) { aggIndexMap[i] = aggIndexMap[j] combined = true if _, ok := correlatedAggMap[aggFunc]; ok { @@ -348,7 +348,7 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu } } for i, col := range p.Schema().Columns { - newFunc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) if err != nil { return nil, nil, err } @@ -373,7 +373,7 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu if p.Schema().Contains(col) { continue } - newFunc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) if err != nil { return nil, nil, err } @@ -561,18 +561,18 @@ func (p *LogicalJoin) ExtractOnCondition( if leftCol != nil && rightCol != nil { if deriveLeft { if isNullRejected(ctx, leftSchema, expr) && !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) { - notNullExpr := expression.BuildNotNullExpr(ctx, leftCol) + notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), leftCol) leftCond = append(leftCond, notNullExpr) } } if deriveRight { if isNullRejected(ctx, rightSchema, expr) && !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { - notNullExpr := expression.BuildNotNullExpr(ctx, rightCol) + notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), rightCol) rightCond = append(rightCond, notNullExpr) } } if binop.FuncName.L == ast.EQ { - cond := expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), arg0, arg1) + cond := expression.NewFunctionInternal(ctx.GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), arg0, arg1) eqCond = append(eqCond, cond.(*expression.ScalarFunction)) continue } @@ -604,13 +604,13 @@ func (p *LogicalJoin) ExtractOnCondition( // `expr AND leftRelaxedCond AND rightRelaxedCond`. Motivation is to push filters down to // children as much as possible. if deriveLeft { - leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, leftSchema) + leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx.GetExprCtx(), expr, leftSchema) if leftRelaxedCond != nil { leftCond = append(leftCond, leftRelaxedCond) } } if deriveRight { - rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, rightSchema) + rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx.GetExprCtx(), expr, rightSchema) if rightRelaxedCond != nil { rightCond = append(rightCond, rightRelaxedCond) } @@ -1257,7 +1257,7 @@ func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan conds := make([]expression.Expression, 0, commonLen) for i := 0; i < commonLen; i++ { lc, rc := lsc.Columns[i], rsc.Columns[i] - cond, err := expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc) + cond, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc) if err != nil { return err } @@ -1315,7 +1315,7 @@ func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where a cnfItems := expression.SplitCNFItems(expr) for _, item := range cnfItems { if con, ok := item.(*expression.Constant); ok && expression.ConstExprConsiderPlanCache(con, useCache) { - ret, _, err := expression.EvalBool(b.ctx, expression.CNFExprs{con}, chunk.Row{}) + ret, _, err := expression.EvalBool(b.ctx.GetExprCtx(), expression.CNFExprs{con}, chunk.Row{}) if err != nil { return nil, errors.Trace(err) } @@ -1343,7 +1343,7 @@ func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where a tp.SetFlen(mysql.MaxRealWidth) tp.SetDecimal(types.UnspecifiedLength) types.SetBinChsClnFlag(tp) - cnfExpres[i] = expression.TryPushCastIntoControlFunctionForHybridType(b.ctx, expr, tp) + cnfExpres[i] = expression.TryPushCastIntoControlFunctionForHybridType(b.ctx.GetExprCtx(), expr, tp) } } selection.Conditions = cnfExpres @@ -1379,7 +1379,7 @@ func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(_ context.Context, // When used to produce a result set column, NAME_CONST() causes the column to have the given name. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_name-const for details if isFuncCall && funcCall.FnName.L == ast.NameConst { - if v, err := evalAstExpr(b.ctx, funcCall.Args[0]); err == nil { + if v, err := evalAstExpr(b.ctx.GetExprCtx(), funcCall.Args[0]); err == nil { if s, err := v.ToString(); err == nil { return model.NewCIStr(s), nil } @@ -1874,7 +1874,7 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggr plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop } for _, col := range child.Schema().Columns { - aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + aggDesc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) if err != nil { return nil, err } @@ -1958,7 +1958,7 @@ func (b *PlanBuilder) buildProjection4Union(_ context.Context, u *LogicalUnionAl childTp := u.children[j].Schema().Columns[i].RetType resultTp = unionJoinFieldType(resultTp, childTp) } - collation, err := expression.CheckAndDeriveCollationFromExprs(b.ctx, "UNION", resultTp.EvalType(), tmpExprs...) + collation, err := expression.CheckAndDeriveCollationFromExprs(b.ctx.GetExprCtx(), "UNION", resultTp.EvalType(), tmpExprs...) if err != nil || collation.Coer == expression.CoercibilityNone { return collate.ErrIllegalMixCollation.GenWithStackByArgs("UNION") } @@ -1981,7 +1981,7 @@ func (b *PlanBuilder) buildProjection4Union(_ context.Context, u *LogicalUnionAl dstType := unionCols[i].RetType srcType := srcCol.RetType if !srcType.Equal(dstType) { - exprs[i] = expression.BuildCastFunction4Union(b.ctx, srcCol, dstType) + exprs[i] = expression.BuildCastFunction4Union(b.ctx.GetExprCtx(), srcCol, dstType) } else { exprs[i] = srcCol } @@ -2105,7 +2105,7 @@ func (b *PlanBuilder) buildSemiJoinForSetOperator( copy(joinPlan.names, leftPlan.OutputNames()) for j := 0; j < len(rightPlan.Schema().Columns); j++ { leftCol, rightCol := leftPlan.Schema().Columns[j], rightPlan.Schema().Columns[j] - eqCond, err := expression.NewFunction(b.ctx, ast.NullEQ, types.NewFieldType(mysql.TypeTiny), leftCol, rightCol) + eqCond, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.NullEQ, types.NewFieldType(mysql.TypeTiny), leftCol, rightCol) if err != nil { return nil, err } @@ -2338,9 +2338,10 @@ func (b *PlanBuilder) checkOrderByInDistinct(byItem *ast.ByItem, idx int, expr e // select distinct count(a) from t group by b order by count(a); ✔ // select distinct a+1 from t order by a+1; ✔ // select distinct a+1 from t order by a+2; ✗ + exprCtx := b.ctx.GetExprCtx() for j := 0; j < length; j++ { // both check original expression & as name - if expr.Equal(b.ctx, originalExprs[j]) || expr.Equal(b.ctx, p.Schema().Columns[j]) { + if expr.Equal(exprCtx, originalExprs[j]) || expr.Equal(exprCtx, p.Schema().Columns[j]) { return nil } } @@ -2355,7 +2356,7 @@ func (b *PlanBuilder) checkOrderByInDistinct(byItem *ast.ByItem, idx int, expr e CheckReferenced: for _, col := range cols { for j := 0; j < length; j++ { - if col.Equal(b.ctx, originalExprs[j]) || col.Equal(b.ctx, p.Schema().Columns[j]) { + if col.Equal(exprCtx, originalExprs[j]) || col.Equal(exprCtx, p.Schema().Columns[j]) { continue CheckReferenced } } @@ -2394,7 +2395,7 @@ func getUintFromNode(ctx PlanContext, n ast.Node, mustInt64orUint64 bool) (uVal if err != nil { return 0, false, false } - str, isNull, err := expression.GetStringFromConstant(ctx, param) + str, isNull, err := expression.GetStringFromConstant(ctx.GetExprCtx(), param) if err != nil { return 0, false, false } @@ -3277,7 +3278,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, false } case *ast.PositionExpr: - pos, isNull, err := expression.PosFromPositionExpr(g.ctx, v) + pos, isNull, err := expression.PosFromPositionExpr(g.ctx.GetExprCtx(), v) if err != nil { g.err = plannererrors.ErrUnknown.GenWithStackByArgs() } @@ -6554,7 +6555,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast for i, item := range orderByItems { col := item.Col bound.CalcFuncs[i] = col - bound.CmpFuncs[i] = expression.GetCmpFunction(b.ctx, col, col) + bound.CmpFuncs[i] = expression.GetCmpFunction(b.ctx.GetExprCtx(), col, col) } return bound, nil } @@ -6577,7 +6578,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast oldTypeFlags := sc.TypeFlags() newTypeFlags := oldTypeFlags.WithIgnoreTruncateErr(true) sc.SetTypeFlags(newTypeFlags) - uVal, isNull, err := expr.EvalInt(b.ctx, chunk.Row{}) + uVal, isNull, err := expr.EvalInt(b.ctx.GetExprCtx(), chunk.Row{}) sc.SetTypeFlags(oldTypeFlags) if uVal < 0 || isNull || err != nil { return nil, plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) @@ -6603,7 +6604,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast funcName = ast.DateSub } - bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx, funcName, col.RetType, col, &expr, &unit) + bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx.GetExprCtx(), funcName, col.RetType, col, &expr, &unit) if err != nil { return nil, err } @@ -6616,7 +6617,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast funcName = ast.Minus } - bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx, funcName, col.RetType, col, &expr) + bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx.GetExprCtx(), funcName, col.RetType, col, &expr) if err != nil { return nil, err } @@ -6658,7 +6659,7 @@ func (b *PlanBuilder) checkWindowFuncArgs(ctx context.Context, p LogicalPlan, wi for _, expr := range windowFuncExpr.Args { expr.Accept(checker) } - desc, err := aggregation.NewWindowFuncDesc(b.ctx, windowFuncExpr.Name, args, checker.InPrepareStmt) + desc, err := aggregation.NewWindowFuncDesc(b.ctx.GetExprCtx(), windowFuncExpr.Name, args, checker.InPrepareStmt) if err != nil { return err } @@ -6777,7 +6778,7 @@ func (b *PlanBuilder) buildWindowFunctions(ctx context.Context, p LogicalPlan, g for _, expr := range windowFunc.Args { expr.Accept(checker) } - desc, err := aggregation.NewWindowFuncDesc(b.ctx, windowFunc.Name, args[preArgs:preArgs+len(windowFunc.Args)], checker.InPrepareStmt) + desc, err := aggregation.NewWindowFuncDesc(b.ctx.GetExprCtx(), windowFunc.Name, args[preArgs:preArgs+len(windowFunc.Args)], checker.InPrepareStmt) if err != nil { return nil, nil, err } @@ -6785,7 +6786,7 @@ func (b *PlanBuilder) buildWindowFunctions(ctx context.Context, p LogicalPlan, g return nil, nil, plannererrors.ErrWrongArguments.GenWithStackByArgs(strings.ToLower(windowFunc.Name)) } preArgs += len(windowFunc.Args) - desc.WrapCastForAggArgs(b.ctx) + desc.WrapCastForAggArgs(b.ctx.GetExprCtx()) descs = append(descs, desc) windowMap[windowFunc] = schema.Len() schema.Append(&expression.Column{ @@ -7656,7 +7657,7 @@ func (b *PlanBuilder) buildProjection4CTEUnion(_ context.Context, seed LogicalPl resSchema := getResultCTESchema(seed.Schema(), b.ctx.GetSessionVars()) for i, col := range recur.Schema().Columns { if !resSchema.Columns[i].RetType.Equal(col.RetType) { - exprs[i] = expression.BuildCastFunction4Union(b.ctx, col, resSchema.Columns[i].RetType) + exprs[i] = expression.BuildCastFunction4Union(b.ctx.GetExprCtx(), col, resSchema.Columns[i].RetType) } else { exprs[i] = col } diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index 24a88593be030..112556420b669 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -412,29 +412,29 @@ func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []exp copy(cpOtherConditions, p.OtherConditions) copy(cpEqualConditions, p.EqualConditions) - ctx := p.SCtx() + exprCtx := p.SCtx().GetExprCtx() // try to substitute columns in these condition. for i, cond := range cpLeftConditions { - if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { + if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { return } } for i, cond := range cpRightConditions { - if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { + if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { return } } for i, cond := range cpOtherConditions { - if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { + if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { return } } for i, cond := range cpEqualConditions { var tmp expression.Expression - if hasFail, tmp = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { + if hasFail, tmp = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { return } cpEqualConditions[i] = tmp.(*expression.ScalarFunction) @@ -1198,7 +1198,7 @@ func extractConstantCols(conditions []expression.Expression, sctx PlanContext, f constObjs []expression.Expression constUniqueIDs = intset.NewFastIntSet() ) - constObjs = expression.ExtractConstantEqColumnsOrScalar(sctx, constObjs, conditions) + constObjs = expression.ExtractConstantEqColumnsOrScalar(sctx.GetExprCtx(), constObjs, conditions) for _, constObj := range constObjs { switch x := constObj.(type) { case *expression.Column: @@ -1348,7 +1348,7 @@ func (la *LogicalApply) ExtractFD() *fd.FDSet { for _, col := range innerPlan.Schema().Columns { if cc.UniqueID == col.CorrelatedColUniqueID { ccc := &cc.Column - cond := expression.NewFunctionInternal(la.SCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), ccc, col) + cond := expression.NewFunctionInternal(la.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), ccc, col) eqCond = append(eqCond, cond.(*expression.ScalarFunction)) } } @@ -1543,8 +1543,9 @@ func (p *LogicalIndexScan) MatchIndexProp(prop *property.PhysicalProperty) (matc return false } sctx := p.SCtx() + exprCtx := sctx.GetExprCtx() for i, col := range p.IdxCols { - if col.Equal(sctx, prop.SortItems[0].Col) { + if col.Equal(exprCtx, prop.SortItems[0].Col) { return matchIndicesProp(sctx, p.IdxCols[i:], p.IdxColLens[i:], prop.SortItems) } else if i >= p.EqCondCount { break @@ -1730,7 +1731,7 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres continue } lCol, lOk := eqFunc.GetArgs()[0].(*expression.Column) - if lOk && lCol.Equal(ds.SCtx(), pkCol) { + if lOk && lCol.Equal(ds.SCtx().GetExprCtx(), pkCol) { _, rOk := eqFunc.GetArgs()[1].(*expression.CorrelatedColumn) if rOk { path.AccessConds = append(path.AccessConds, filter) @@ -1740,7 +1741,7 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres } } rCol, rOk := eqFunc.GetArgs()[1].(*expression.Column) - if rOk && rCol.Equal(ds.SCtx(), pkCol) { + if rOk && rCol.Equal(ds.SCtx().GetExprCtx(), pkCol) { _, lOk := eqFunc.GetArgs()[0].(*expression.CorrelatedColumn) if lOk { path.AccessConds = append(path.AccessConds, filter) diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index c3a8335a62aa1..fada1322343f6 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -529,12 +529,13 @@ func (p *PhysicalHashJoin) extractUsedCols(parentUsedCols []*expression.Column) func prunePhysicalColumnForHashJoinChild(sctx PlanContext, hashJoin *PhysicalHashJoin, joinUsedCols []*expression.Column, sender *PhysicalExchangeSender) error { var err error - joinUsed := expression.GetUsedList(sctx, joinUsedCols, sender.Schema()) + exprCtx := sctx.GetExprCtx() + joinUsed := expression.GetUsedList(exprCtx, joinUsedCols, sender.Schema()) hashCols := make([]*expression.Column, len(sender.HashCols)) for i, mppCol := range sender.HashCols { hashCols[i] = mppCol.Col } - hashUsed := expression.GetUsedList(sctx, hashCols, sender.Schema()) + hashUsed := expression.GetUsedList(exprCtx, hashCols, sender.Schema()) needPrune := false usedExprs := make([]expression.Expression, len(sender.Schema().Columns)) diff --git a/pkg/planner/core/pb_to_plan.go b/pkg/planner/core/pb_to_plan.go index 04dcb78a24e7d..af455809ee23a 100644 --- a/pkg/planner/core/pb_to_plan.go +++ b/pkg/planner/core/pb_to_plan.go @@ -152,7 +152,7 @@ func (b *PBPlanBuilder) buildTableScanSchema(tblInfo *model.TableInfo, columns [ } func (b *PBPlanBuilder) pbToSelection(e *tipb.Executor) (PhysicalPlan, error) { - conds, err := expression.PBToExprs(b.sctx, e.Selection.Conditions, b.tps) + conds, err := expression.PBToExprs(b.sctx.GetExprCtx(), e.Selection.Conditions, b.tps) if err != nil { return nil, err } @@ -165,8 +165,9 @@ func (b *PBPlanBuilder) pbToSelection(e *tipb.Executor) (PhysicalPlan, error) { func (b *PBPlanBuilder) pbToTopN(e *tipb.Executor) (PhysicalPlan, error) { topN := e.TopN byItems := make([]*util.ByItems, 0, len(topN.OrderBy)) + exprCtx := b.sctx.GetExprCtx() for _, item := range topN.OrderBy { - expr, err := expression.PBToExpr(b.sctx, item.Expr, b.tps) + expr, err := expression.PBToExpr(exprCtx, item.Expr, b.tps) if err != nil { return nil, errors.Trace(err) } @@ -221,14 +222,15 @@ func (b *PBPlanBuilder) buildAggSchema(aggFuncs []*aggregation.AggFuncDesc, grou func (b *PBPlanBuilder) getAggInfo(executor *tipb.Executor) ([]*aggregation.AggFuncDesc, []expression.Expression, error) { var err error aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(executor.Aggregation.AggFunc)) + exprCtx := b.sctx.GetExprCtx() for _, expr := range executor.Aggregation.AggFunc { - aggFunc, err := aggregation.PBExprToAggFuncDesc(b.sctx, expr, b.tps) + aggFunc, err := aggregation.PBExprToAggFuncDesc(exprCtx, expr, b.tps) if err != nil { return nil, nil, errors.Trace(err) } aggFuncs = append(aggFuncs, aggFunc) } - groupBys, err := expression.PBToExprs(b.sctx, executor.Aggregation.GetGroupBy(), b.tps) + groupBys, err := expression.PBToExprs(exprCtx, executor.Aggregation.GetGroupBy(), b.tps) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/pkg/planner/core/physical_plans.go b/pkg/planner/core/physical_plans.go index c156a9b386e79..3cd695b907d56 100644 --- a/pkg/planner/core/physical_plans.go +++ b/pkg/planner/core/physical_plans.go @@ -961,7 +961,7 @@ func (ts *PhysicalTableScan) ResolveCorrelatedColumns() ([]*ranger.Range, error) pkIdx := tables.FindPrimaryIndex(ts.Table) idxCols, idxColLens := expression.IndexInfo2PrefixCols(ts.Columns, ts.Schema().Columns, pkIdx) for _, cond := range access { - newCond, err := expression.SubstituteCorCol2Constant(ctx, cond) + newCond, err := expression.SubstituteCorCol2Constant(ctx.GetExprCtx(), cond) if err != nil { return nil, err } diff --git a/pkg/planner/core/plan_cache.go b/pkg/planner/core/plan_cache.go index 5d87f2c92c798..4814c0d95616b 100644 --- a/pkg/planner/core/plan_cache.go +++ b/pkg/planner/core/plan_cache.go @@ -58,7 +58,7 @@ func SetParameterValuesIntoSCtx(sctx PlanContext, isNonPrep bool, markers []ast. vars := sctx.GetSessionVars() vars.PlanCacheParams.Reset() for i, usingParam := range params { - val, err := usingParam.Eval(sctx, chunk.Row{}) + val, err := usingParam.Eval(sctx.GetExprCtx(), chunk.Row{}) if err != nil { return err } @@ -619,7 +619,7 @@ func rebuildRange(p Plan) error { } func convertConstant2Datum(ctx PlanContext, con *expression.Constant, target *types.FieldType) (*types.Datum, error) { - val, err := con.Eval(ctx, chunk.Row{}) + val, err := con.Eval(ctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil, err } diff --git a/pkg/planner/core/plan_to_pb.go b/pkg/planner/core/plan_to_pb.go index 912eddb25b635..69ca809a4c504 100644 --- a/pkg/planner/core/plan_to_pb.go +++ b/pkg/planner/core/plan_to_pb.go @@ -37,7 +37,7 @@ func (p *PhysicalExpand) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Ex return p.toPBV2(ctx, storeType) } client := ctx.GetClient() - groupingSetsPB, err := p.GroupingSets.ToPB(ctx, client) + groupingSetsPB, err := p.GroupingSets.ToPB(ctx.GetExprCtx(), client) if err != nil { return nil, err } @@ -59,8 +59,9 @@ func (p *PhysicalExpand) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Ex func (p *PhysicalExpand) toPBV2(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() projExprsPB := make([]*tipb.ExprSlice, 0, len(p.LevelExprs)) + exprCtx := ctx.GetExprCtx() for _, exprs := range p.LevelExprs { - expressionsPB, err := expression.ExpressionsToPBList(ctx, exprs, client) + expressionsPB, err := expression.ExpressionsToPBList(exprCtx, exprs, client) if err != nil { return nil, err } @@ -85,7 +86,7 @@ func (p *PhysicalExpand) toPBV2(ctx PlanContext, storeType kv.StoreType) (*tipb. // ToPB implements PhysicalPlan ToPB interface. func (p *PhysicalHashAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() - groupByExprs, err := expression.ExpressionsToPBList(ctx, p.GroupByItems, client) + groupByExprs, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), p.GroupByItems, client) if err != nil { return nil, err } @@ -93,7 +94,7 @@ func (p *PhysicalHashAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.E GroupBy: groupByExprs, } for _, aggFunc := range p.AggFuncs { - agg, err := aggregation.AggFuncToPBExpr(ctx, client, aggFunc, storeType) + agg, err := aggregation.AggFuncToPBExpr(ctx.GetExprCtx(), client, aggFunc, storeType) if err != nil { return nil, errors.Trace(err) } @@ -120,7 +121,8 @@ func (p *PhysicalHashAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.E // ToPB implements PhysicalPlan ToPB interface. func (p *PhysicalStreamAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() - groupByExprs, err := expression.ExpressionsToPBList(ctx, p.GroupByItems, client) + exprCtx := ctx.GetExprCtx() + groupByExprs, err := expression.ExpressionsToPBList(exprCtx, p.GroupByItems, client) if err != nil { return nil, err } @@ -128,7 +130,7 @@ func (p *PhysicalStreamAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb GroupBy: groupByExprs, } for _, aggFunc := range p.AggFuncs { - agg, err := aggregation.AggFuncToPBExpr(ctx, client, aggFunc, storeType) + agg, err := aggregation.AggFuncToPBExpr(exprCtx, client, aggFunc, storeType) if err != nil { return nil, errors.Trace(err) } @@ -149,7 +151,7 @@ func (p *PhysicalStreamAgg) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb // ToPB implements PhysicalPlan ToPB interface. func (p *PhysicalSelection) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() - conditions, err := expression.ExpressionsToPBList(ctx, p.Conditions, client) + conditions, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), p.Conditions, client) if err != nil { return nil, err } @@ -171,7 +173,7 @@ func (p *PhysicalSelection) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb // ToPB implements PhysicalPlan ToPB interface. func (p *PhysicalProjection) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Executor, error) { client := ctx.GetClient() - exprs, err := expression.ExpressionsToPBList(ctx, p.Exprs, client) + exprs, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), p.Exprs, client) if err != nil { return nil, err } @@ -196,11 +198,12 @@ func (p *PhysicalTopN) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Exec topNExec := &tipb.TopN{ Limit: p.Count, } + exprCtx := ctx.GetExprCtx() for _, item := range p.ByItems { - topNExec.OrderBy = append(topNExec.OrderBy, expression.SortByItemToPB(ctx, client, item.Expr, item.Desc)) + topNExec.OrderBy = append(topNExec.OrderBy, expression.SortByItemToPB(exprCtx, client, item.Expr, item.Desc)) } for _, item := range p.PartitionBy { - topNExec.PartitionBy = append(topNExec.PartitionBy, expression.SortByItemToPB(ctx, client, item.Col.Clone(), item.Desc)) + topNExec.PartitionBy = append(topNExec.PartitionBy, expression.SortByItemToPB(exprCtx, client, item.Col.Clone(), item.Desc)) } executorID := "" if storeType == kv.TiFlash { @@ -222,7 +225,7 @@ func (p *PhysicalLimit) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Exe } executorID := "" for _, item := range p.PartitionBy { - limitExec.PartitionBy = append(limitExec.PartitionBy, expression.SortByItemToPB(ctx, client, item.Col.Clone(), item.Desc)) + limitExec.PartitionBy = append(limitExec.PartitionBy, expression.SortByItemToPB(ctx.GetExprCtx(), client, item.Col.Clone(), item.Desc)) } if storeType == kv.TiFlash { var err error @@ -248,7 +251,7 @@ func (p *PhysicalTableScan) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb if len(p.LateMaterializationFilterCondition) > 0 { client := ctx.GetClient() - conditions, err := expression.ExpressionsToPBList(ctx, p.LateMaterializationFilterCondition, client) + conditions, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), p.LateMaterializationFilterCondition, client) if err != nil { return nil, err } @@ -269,7 +272,7 @@ func (p *PhysicalTableScan) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb if storeType == kv.TiFlash { executorID = p.ExplainID().String() } - err = tables.SetPBColumnsDefaultValue(ctx, tsExec.Columns, p.Columns) + err = tables.SetPBColumnsDefaultValue(ctx.GetExprCtx(), tsExec.Columns, p.Columns) return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tsExec, ExecutorId: &executorID}, err } @@ -278,7 +281,7 @@ func (p *PhysicalTableScan) partitionTableScanToPBForFlash(ctx PlanContext) (*ti if len(p.LateMaterializationFilterCondition) > 0 { client := ctx.GetClient() - conditions, err := expression.ExpressionsToPBList(ctx, p.LateMaterializationFilterCondition, client) + conditions, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), p.LateMaterializationFilterCondition, client) if err != nil { return nil, err } @@ -295,7 +298,7 @@ func (p *PhysicalTableScan) partitionTableScanToPBForFlash(ctx PlanContext) (*ti ptsExec.Desc = p.Desc executorID := p.ExplainID().String() - err = tables.SetPBColumnsDefaultValue(ctx, ptsExec.Columns, p.Columns) + err = tables.SetPBColumnsDefaultValue(ctx.GetExprCtx(), ptsExec.Columns, p.Columns) return &tipb.Executor{Tp: tipb.ExecType_TypePartitionTableScan, PartitionTableScan: ptsExec, ExecutorId: &executorID}, err } @@ -387,7 +390,7 @@ func (e *PhysicalExchangeSender) ToPB(ctx PlanContext, storeType kv.StoreType) ( } allFieldTypes = append(allFieldTypes, pbType) } - hashColPb, err := expression.ExpressionsToPBList(ctx, hashCols, ctx.GetClient()) + hashColPb, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), hashCols, ctx.GetClient()) if err != nil { return nil, errors.Trace(err) } @@ -520,20 +523,20 @@ func (p *PhysicalHashJoin) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb. return nil, errors.Trace(err) } - left, err := expression.ExpressionsToPBList(ctx, leftKeys, client) + left, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), leftKeys, client) if err != nil { return nil, err } - right, err := expression.ExpressionsToPBList(ctx, rightKeys, client) + right, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), rightKeys, client) if err != nil { return nil, err } - leftConditions, err := expression.ExpressionsToPBList(ctx, p.LeftConditions, client) + leftConditions, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), p.LeftConditions, client) if err != nil { return nil, err } - rightConditions, err := expression.ExpressionsToPBList(ctx, p.RightConditions, client) + rightConditions, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), p.RightConditions, client) if err != nil { return nil, err } @@ -553,11 +556,11 @@ func (p *PhysicalHashJoin) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb. } else { otherConditionsInJoin = p.OtherConditions } - otherConditions, err := expression.ExpressionsToPBList(ctx, otherConditionsInJoin, client) + otherConditions, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), otherConditionsInJoin, client) if err != nil { return nil, err } - otherEqConditions, err := expression.ExpressionsToPBList(ctx, otherEqConditionsFromIn, client) + otherEqConditions, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), otherEqConditionsFromIn, client) if err != nil { return nil, err } @@ -641,7 +644,7 @@ func (fb *FrameBound) ToPB(ctx PlanContext) (*tipb.WindowFrameBound, error) { pbBound.Offset = &offset if fb.IsExplicitRange { - rangeFrame, err := expression.ExpressionsToPBList(ctx, fb.CalcFuncs, ctx.GetClient()) + rangeFrame, err := expression.ExpressionsToPBList(ctx.GetExprCtx(), fb.CalcFuncs, ctx.GetClient()) if err != nil { return nil, err } @@ -661,13 +664,13 @@ func (p *PhysicalWindow) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Ex windowExec.FuncDesc = make([]*tipb.Expr, 0, len(p.WindowFuncDescs)) for _, desc := range p.WindowFuncDescs { - windowExec.FuncDesc = append(windowExec.FuncDesc, aggregation.WindowFuncToPBExpr(ctx, client, desc)) + windowExec.FuncDesc = append(windowExec.FuncDesc, aggregation.WindowFuncToPBExpr(ctx.GetExprCtx(), client, desc)) } for _, item := range p.PartitionBy { - windowExec.PartitionBy = append(windowExec.PartitionBy, expression.SortByItemToPB(ctx, client, item.Col.Clone(), item.Desc)) + windowExec.PartitionBy = append(windowExec.PartitionBy, expression.SortByItemToPB(ctx.GetExprCtx(), client, item.Col.Clone(), item.Desc)) } for _, item := range p.OrderBy { - windowExec.OrderBy = append(windowExec.OrderBy, expression.SortByItemToPB(ctx, client, item.Col.Clone(), item.Desc)) + windowExec.OrderBy = append(windowExec.OrderBy, expression.SortByItemToPB(ctx.GetExprCtx(), client, item.Col.Clone(), item.Desc)) } if p.Frame != nil { @@ -715,7 +718,7 @@ func (p *PhysicalSort) ToPB(ctx PlanContext, storeType kv.StoreType) (*tipb.Exec sortExec := &tipb.Sort{} for _, item := range p.ByItems { - sortExec.ByItems = append(sortExec.ByItems, expression.SortByItemToPB(ctx, client, item.Expr, item.Desc)) + sortExec.ByItems = append(sortExec.ByItems, expression.SortByItemToPB(ctx.GetExprCtx(), client, item.Expr, item.Desc)) } isPartialSort := p.IsPartialSort sortExec.IsPartialSort = &isPartialSort diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index 28de96c043bbe..a9ba17583f614 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -1481,7 +1481,7 @@ func (b *PlanBuilder) buildAdmin(ctx context.Context, as *ast.AdminStmt) (Plan, func (b *PlanBuilder) buildPhysicalIndexLookUpReader(_ context.Context, dbName model.CIStr, tbl table.Table, idx *model.IndexInfo) (Plan, error) { tblInfo := tbl.Meta() physicalID, isPartition := getPhysicalID(tbl) - fullExprCols, _, err := expression.TableInfo2SchemaAndNames(b.ctx, dbName, tblInfo) + fullExprCols, _, err := expression.TableInfo2SchemaAndNames(b.ctx.GetExprCtx(), dbName, tblInfo) if err != nil { return nil, err } @@ -1922,7 +1922,7 @@ func (b *PlanBuilder) getMustAnalyzedColumns(tbl *ast.TableName, cols *calcOnceM if len(tblInfo.Indices) > 0 { // Add indexed columns. // Some indexed columns are generated columns so we also need to add the columns that make up those generated columns. - columns, _, err := expression.ColumnInfos2ColumnsAndNames(b.ctx, tbl.Schema, tbl.Name, tblInfo.Columns, tblInfo) + columns, _, err := expression.ColumnInfos2ColumnsAndNames(b.ctx.GetExprCtx(), tbl.Schema, tbl.Name, tblInfo.Columns, tblInfo) if err != nil { return nil, err } @@ -3555,12 +3555,12 @@ func (b *PlanBuilder) getDefaultValueForInsert(col *table.Column) (*expression.C err error ) if col.DefaultIsExpr && col.DefaultExpr != nil { - value, err = table.EvalColDefaultExpr(b.ctx, col.ToInfo(), col.DefaultExpr) + value, err = table.EvalColDefaultExpr(b.ctx.GetExprCtx(), col.ToInfo(), col.DefaultExpr) } else { if err := table.CheckNoDefaultValueForInsert(b.ctx.GetSessionVars().StmtCtx, col.ToInfo()); err != nil { return nil, err } - value, err = table.GetColDefaultValue(b.ctx, col.ToInfo()) + value, err = table.GetColDefaultValue(b.ctx.GetExprCtx(), col.ToInfo()) } if err != nil { return nil, err @@ -3633,7 +3633,7 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( return nil, err } // Build Schema with DBName otherwise ColumnRef with DBName cannot match any Column in Schema. - schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx, tn.Schema, tableInfo) + schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx.GetExprCtx(), tn.Schema, tableInfo) if err != nil { return nil, err } @@ -4125,7 +4125,7 @@ func (b *PlanBuilder) buildLoadData(ctx context.Context, ld *ast.LoadDataStmt) ( db := b.ctx.GetSessionVars().CurrentDB return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(db, tableInfo.Name.O) } - schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx, model.NewCIStr(""), tableInfo) + schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx.GetExprCtx(), model.NewCIStr(""), tableInfo) if err != nil { return nil, err } @@ -4225,7 +4225,7 @@ func (b *PlanBuilder) buildImportInto(ctx context.Context, ld *ast.ImportIntoStm db := b.ctx.GetSessionVars().CurrentDB return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(db, tableInfo.Name.O) } - schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx, model.NewCIStr(""), tableInfo) + schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx.GetExprCtx(), model.NewCIStr(""), tableInfo) if err != nil { return nil, err } @@ -4336,7 +4336,7 @@ func (b *PlanBuilder) buildSplitIndexRegion(node *ast.SplitRegionStmt) (Plan, er return nil, plannererrors.ErrKeyDoesNotExist.GenWithStackByArgs(node.IndexName, tblInfo.Name) } mockTablePlan := LogicalTableDual{}.Init(b.ctx, b.getSelectOffset()) - schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx, node.Table.Schema, tblInfo) + schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx.GetExprCtx(), node.Table.Schema, tblInfo) if err != nil { return nil, err } @@ -4429,7 +4429,7 @@ func (b *PlanBuilder) convertValue(valueItem ast.ExprNode, mockTablePlan Logical if !ok { return d, errors.New("Expect constant values") } - value, err := constant.Eval(b.ctx, chunk.Row{}) + value, err := constant.Eval(b.ctx.GetExprCtx(), chunk.Row{}) if err != nil { return d, err } @@ -4451,7 +4451,7 @@ func (b *PlanBuilder) buildSplitTableRegion(node *ast.SplitRegionStmt) (Plan, er tblInfo := node.Table.TableInfo handleColInfos := buildHandleColumnInfos(tblInfo) mockTablePlan := LogicalTableDual{}.Init(b.ctx, b.getSelectOffset()) - schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx, node.Table.Schema, tblInfo) + schema, names, err := expression.TableInfo2SchemaAndNames(b.ctx.GetExprCtx(), node.Table.Schema, tblInfo) if err != nil { return nil, err } diff --git a/pkg/planner/core/point_get_plan.go b/pkg/planner/core/point_get_plan.go index 5cbe2f75eabeb..d09672ae67c63 100644 --- a/pkg/planner/core/point_get_plan.go +++ b/pkg/planner/core/point_get_plan.go @@ -693,7 +693,7 @@ func newBatchPointGetPlan( if err != nil { return nil } - d, err = con.Eval(ctx, chunk.Row{}) + d, err = con.Eval(ctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil } @@ -839,7 +839,7 @@ func newBatchPointGetPlan( if err != nil { return nil } - d, err := con.Eval(ctx, chunk.Row{}) + d, err := con.Eval(ctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil } @@ -878,7 +878,7 @@ func newBatchPointGetPlan( if err != nil { return nil } - d, err := con.Eval(ctx, chunk.Row{}) + d, err := con.Eval(ctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil } @@ -1453,7 +1453,7 @@ func getNameValuePairs(ctx PlanContext, tbl *model.TableInfo, tblName model.CISt if err != nil { return nil, false } - d, err = con.Eval(ctx, chunk.Row{}) + d, err = con.Eval(ctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil, false } @@ -1467,7 +1467,7 @@ func getNameValuePairs(ctx PlanContext, tbl *model.TableInfo, tblName model.CISt if err != nil { return nil, false } - d, err = con.Eval(ctx, chunk.Row{}) + d, err = con.Eval(ctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil, false } @@ -1774,7 +1774,7 @@ func buildOrderedList(ctx PlanContext, plan Plan, list []*ast.Assignment, if err != nil { return nil, true } - expr = expression.BuildCastFunction(ctx, expr, col.GetType()) + expr = expression.BuildCastFunction(ctx.GetExprCtx(), expr, col.GetType()) if allAssignmentsAreConstant { _, isConst := expr.(*expression.Constant) allAssignmentsAreConstant = isConst diff --git a/pkg/planner/core/resolve_indices.go b/pkg/planner/core/resolve_indices.go index 06a6a81b40024..e80fef57d727d 100644 --- a/pkg/planner/core/resolve_indices.go +++ b/pkg/planner/core/resolve_indices.go @@ -95,7 +95,7 @@ func (p *PhysicalHashJoin) ResolveIndicesItself() (err error) { return err } p.RightJoinKeys[i] = rArg.(*expression.Column) - p.EqualConditions[i] = expression.NewFunctionInternal(ctx, fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) + p.EqualConditions[i] = expression.NewFunctionInternal(ctx.GetExprCtx(), fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) } for i, fun := range p.NAEqualConditions { lArg, err := fun.GetArgs()[0].ResolveIndices(lSchema) @@ -108,7 +108,7 @@ func (p *PhysicalHashJoin) ResolveIndicesItself() (err error) { return err } p.RightNAJoinKeys[i] = rArg.(*expression.Column) - p.NAEqualConditions[i] = expression.NewFunctionInternal(ctx, fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) + p.NAEqualConditions[i] = expression.NewFunctionInternal(ctx.GetExprCtx(), fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) } for i, expr := range p.LeftConditions { p.LeftConditions[i], err = expr.ResolveIndices(lSchema) @@ -412,7 +412,7 @@ func (p *PhysicalIndexReader) ResolveIndices() (err error) { if err != nil { // Check if there is duplicate virtual expression column matched. sctx := p.SCtx() - newExprCol, isOK := col.ResolveIndicesByVirtualExpr(sctx, p.indexPlan.Schema()) + newExprCol, isOK := col.ResolveIndicesByVirtualExpr(sctx.GetExprCtx(), p.indexPlan.Schema()) if isOK { p.OutputColumns[i] = newExprCol.(*expression.Column) continue @@ -492,7 +492,7 @@ func (p *PhysicalSelection) ResolveIndices() (err error) { p.Conditions[i], err = expr.ResolveIndices(p.children[0].Schema()) if err != nil { // Check if there is duplicate virtual expression column matched. - newCond, isOk := expr.ResolveIndicesByVirtualExpr(p.SCtx(), p.children[0].Schema()) + newCond, isOk := expr.ResolveIndicesByVirtualExpr(p.SCtx().GetExprCtx(), p.children[0].Schema()) if isOk { p.Conditions[i] = newCond continue diff --git a/pkg/planner/core/rule_aggregation_elimination.go b/pkg/planner/core/rule_aggregation_elimination.go index ebf689ac9ceca..75aa71dedd012 100644 --- a/pkg/planner/core/rule_aggregation_elimination.go +++ b/pkg/planner/core/rule_aggregation_elimination.go @@ -184,7 +184,7 @@ func ConvertAggToProj(agg *LogicalAggregation, schema *expression.Schema) (bool, Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)), }.Init(agg.SCtx(), agg.QueryBlockOffset()) for _, fun := range agg.AggFuncs { - ok, expr := rewriteExpr(agg.SCtx(), fun) + ok, expr := rewriteExpr(agg.SCtx().GetExprCtx(), fun) if !ok { return false, nil } diff --git a/pkg/planner/core/rule_aggregation_push_down.go b/pkg/planner/core/rule_aggregation_push_down.go index 1c439c0fc99d4..395583657be9d 100644 --- a/pkg/planner/core/rule_aggregation_push_down.go +++ b/pkg/planner/core/rule_aggregation_push_down.go @@ -196,7 +196,7 @@ func (*aggregationPushDownSolver) addGbyCol(ctx PlanContext, gbyCols []*expressi for _, c := range cols { duplicate := false for _, gbyCol := range gbyCols { - if c.Equal(ctx, gbyCol) { + if c.Equal(ctx.GetExprCtx(), gbyCol) { duplicate = true break } @@ -291,7 +291,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(oldAgg *LogicalAggregation, func (*aggregationPushDownSolver) getDefaultValues(agg *LogicalAggregation) ([]types.Datum, bool) { defaultValues := make([]types.Datum, 0, agg.Schema().Len()) for _, aggFunc := range agg.AggFuncs { - value, existsDefaultValue := aggFunc.EvalNullValueInOuterJoin(agg.SCtx(), agg.children[0].Schema()) + value, existsDefaultValue := aggFunc.EvalNullValueInOuterJoin(agg.SCtx().GetExprCtx(), agg.children[0].Schema()) if !existsDefaultValue { return nil, false } @@ -340,7 +340,7 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx PlanContext, aggFuncs []*aggr newAggFuncDescs = append(newAggFuncDescs, newFuncs...) } for _, gbyCol := range gbyCols { - firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) if err != nil { return nil, err } @@ -395,16 +395,16 @@ func (*aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, uni newAggFunc := aggFunc.Clone() newArgs := make([]expression.Expression, 0, len(newAggFunc.Args)) for _, arg := range newAggFunc.Args { - newArgs = append(newArgs, expression.ColumnSubstitute(ctx, arg, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns))) + newArgs = append(newArgs, expression.ColumnSubstitute(ctx.GetExprCtx(), arg, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns))) } newAggFunc.Args = newArgs newAgg.AggFuncs = append(newAgg.AggFuncs, newAggFunc) } for _, gbyExpr := range agg.GroupByItems { - newExpr := expression.ColumnSubstitute(ctx, gbyExpr, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns)) + newExpr := expression.ColumnSubstitute(ctx.GetExprCtx(), gbyExpr, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns)) newAgg.GroupByItems = append(newAgg.GroupByItems, newExpr) // TODO: if there is a duplicated first_row function, we can delete it. - firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, []expression.Expression{gbyExpr}, false) + firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{gbyExpr}, false) if err != nil { return nil, err } @@ -557,7 +557,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim noSideEffects := true newGbyItems := make([]expression.Expression, 0, len(agg.GroupByItems)) for _, gbyItem := range agg.GroupByItems { - _, failed, groupBy := expression.ColumnSubstituteImpl(ctx, gbyItem, proj.schema, proj.Exprs, true) + _, failed, groupBy := expression.ColumnSubstituteImpl(ctx.GetExprCtx(), gbyItem, proj.schema, proj.Exprs, true) if failed { noSideEffects = false break @@ -577,7 +577,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim oldAggFuncsArgs = append(oldAggFuncsArgs, aggFunc.Args) newArgs := make([]expression.Expression, 0, len(aggFunc.Args)) for _, arg := range aggFunc.Args { - _, failed, newArg := expression.ColumnSubstituteImpl(ctx, arg, proj.schema, proj.Exprs, true) + _, failed, newArg := expression.ColumnSubstituteImpl(ctx.GetExprCtx(), arg, proj.schema, proj.Exprs, true) if failed { noSideEffects = false break @@ -594,7 +594,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim oldAggOrderItems = append(oldAggOrderItems, aggFunc.OrderByItems) newOrderByItems := make([]expression.Expression, 0, len(aggFunc.OrderByItems)) for _, oby := range aggFunc.OrderByItems { - _, failed, byItem := expression.ColumnSubstituteImpl(ctx, oby.Expr, proj.schema, proj.Exprs, true) + _, failed, byItem := expression.ColumnSubstituteImpl(ctx.GetExprCtx(), oby.Expr, proj.schema, proj.Exprs, true) if failed { noSideEffects = false break diff --git a/pkg/planner/core/rule_aggregation_skew_rewrite.go b/pkg/planner/core/rule_aggregation_skew_rewrite.go index ec07a034ba702..3e9588e8b062f 100644 --- a/pkg/planner/core/rule_aggregation_skew_rewrite.go +++ b/pkg/planner/core/rule_aggregation_skew_rewrite.go @@ -119,7 +119,7 @@ func (a *skewDistinctAggRewriter) rewriteSkewDistinctAgg(agg *LogicalAggregation } for _, arg := range aggFunc.Args { - firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, + firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{arg}, false) if err != nil { return nil @@ -155,7 +155,7 @@ func (a *skewDistinctAggRewriter) rewriteSkewDistinctAgg(agg *LogicalAggregation if newAggFunc.Name == ast.AggFuncCount { cntIndexes = append(cntIndexes, i) - sumAggFunc, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncSum, + sumAggFunc, err := aggregation.NewAggFuncDesc(agg.SCtx().GetExprCtx(), ast.AggFuncSum, []expression.Expression{aggCol}, false) if err != nil { return nil @@ -179,7 +179,7 @@ func (a *skewDistinctAggRewriter) rewriteSkewDistinctAgg(agg *LogicalAggregation // SELECT count(DISTINCT a) FROM t GROUP BY b; // column b is not in the output schema, we have to add it to the bottom agg schema if firstRowCols.Has(int(col.UniqueID)) { - firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, + firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) if err != nil { return nil @@ -225,7 +225,7 @@ func (a *skewDistinctAggRewriter) rewriteSkewDistinctAgg(agg *LogicalAggregation exprType := proj.Exprs[index].GetType() targetType := agg.schema.Columns[index].GetType() if !exprType.Equal(targetType) { - proj.Exprs[index] = expression.BuildCastFunction(agg.SCtx(), proj.Exprs[index], targetType) + proj.Exprs[index] = expression.BuildCastFunction(agg.SCtx().GetExprCtx(), proj.Exprs[index], targetType) } } proj.SetSchema(agg.schema.Clone()) diff --git a/pkg/planner/core/rule_build_key_info.go b/pkg/planner/core/rule_build_key_info.go index 819e9101aa4dc..6afed3341d68d 100644 --- a/pkg/planner/core/rule_build_key_info.go +++ b/pkg/planner/core/rule_build_key_info.go @@ -198,17 +198,18 @@ func (p *LogicalJoin) BuildKeyInfo(selfSchema *expression.Schema, childSchema [] // If one sides (a, b) is a unique key, then the unique key information is remained. // But we don't consider this situation currently. // Only key made by one column is considered now. + exprCtx := p.SCtx().GetExprCtx() for _, expr := range p.EqualConditions { ln := expr.GetArgs()[0].(*expression.Column) rn := expr.GetArgs()[1].(*expression.Column) for _, key := range childSchema[0].Keys { - if len(key) == 1 && key[0].Equal(p.SCtx(), ln) { + if len(key) == 1 && key[0].Equal(exprCtx, ln) { lOk = true break } } for _, key := range childSchema[1].Keys { - if len(key) == 1 && key[0].Equal(p.SCtx(), rn) { + if len(key) == 1 && key[0].Equal(exprCtx, rn) { rOk = true break } diff --git a/pkg/planner/core/rule_column_pruning.go b/pkg/planner/core/rule_column_pruning.go index da1d1f5e2a72a..794c64a42514d 100644 --- a/pkg/planner/core/rule_column_pruning.go +++ b/pkg/planner/core/rule_column_pruning.go @@ -78,7 +78,7 @@ func (p *LogicalExpand) PruneColumns(parentUsedCols []*expression.Column, opt *l // Expand need those extra redundant distinct group by columns projected from underlying projection. // distinct GroupByCol must be used by aggregate above, to make sure this, append distinctGroupByCol again. parentUsedCols = append(parentUsedCols, p.distinctGroupByCol...) - used := expression.GetUsedList(p.SCtx(), parentUsedCols, p.Schema()) + used := expression.GetUsedList(p.SCtx().GetExprCtx(), parentUsedCols, p.Schema()) prunedColumns := make([]*expression.Column, 0) for i := len(used) - 1; i >= 0; i-- { if !used[i] { @@ -100,7 +100,7 @@ func (p *LogicalExpand) PruneColumns(parentUsedCols []*expression.Column, opt *l // PruneColumns implements LogicalPlan interface. // If any expression has SetVar function or Sleep function, we do not prune it. func (p *LogicalProjection) PruneColumns(parentUsedCols []*expression.Column, opt *logicalOptimizeOp) (LogicalPlan, error) { - used := expression.GetUsedList(p.SCtx(), parentUsedCols, p.schema) + used := expression.GetUsedList(p.SCtx().GetExprCtx(), parentUsedCols, p.schema) prunedColumns := make([]*expression.Column, 0) // for implicit projected cols, once the ancestor doesn't use it, the implicit expr will be automatically pruned here. @@ -137,7 +137,7 @@ func (p *LogicalSelection) PruneColumns(parentUsedCols []*expression.Column, opt // PruneColumns implements LogicalPlan interface. func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column, opt *logicalOptimizeOp) (LogicalPlan, error) { child := la.children[0] - used := expression.GetUsedList(la.SCtx(), parentUsedCols, la.Schema()) + used := expression.GetUsedList(la.SCtx().GetExprCtx(), parentUsedCols, la.Schema()) prunedColumns := make([]*expression.Column, 0) prunedFunctions := make([]*aggregation.AggFuncDesc, 0) prunedGroupByItems := make([]expression.Expression, 0) @@ -176,9 +176,9 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column, var err error var newAgg *aggregation.AggFuncDesc if allFirstRow { - newAgg, err = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncFirstRow, []expression.Expression{expression.NewOne()}, false) + newAgg, err = aggregation.NewAggFuncDesc(la.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{expression.NewOne()}, false) } else { - newAgg, err = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false) + newAgg, err = aggregation.NewAggFuncDesc(la.SCtx().GetExprCtx(), ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false) } if err != nil { return nil, err @@ -292,7 +292,7 @@ func (lt *LogicalTopN) PruneColumns(parentUsedCols []*expression.Column, opt *lo // PruneColumns implements LogicalPlan interface. func (p *LogicalUnionAll) PruneColumns(parentUsedCols []*expression.Column, opt *logicalOptimizeOp) (LogicalPlan, error) { - used := expression.GetUsedList(p.SCtx(), parentUsedCols, p.schema) + used := expression.GetUsedList(p.SCtx().GetExprCtx(), parentUsedCols, p.schema) hasBeenUsed := false for i := range used { hasBeenUsed = hasBeenUsed || used[i] @@ -367,10 +367,10 @@ func (p *LogicalUnionScan) PruneColumns(parentUsedCols []*expression.Column, opt // PruneColumns implements LogicalPlan interface. func (ds *DataSource) PruneColumns(parentUsedCols []*expression.Column, opt *logicalOptimizeOp) (LogicalPlan, error) { - used := expression.GetUsedList(ds.SCtx(), parentUsedCols, ds.schema) + used := expression.GetUsedList(ds.SCtx().GetExprCtx(), parentUsedCols, ds.schema) exprCols := expression.ExtractColumnsFromExpressions(nil, ds.allConds, nil) - exprUsed := expression.GetUsedList(ds.SCtx(), exprCols, ds.schema) + exprUsed := expression.GetUsedList(ds.SCtx().GetExprCtx(), exprCols, ds.schema) prunedColumns := make([]*expression.Column, 0) originSchemaColumns := ds.schema.Columns @@ -436,7 +436,7 @@ func (p *LogicalMemTable) PruneColumns(parentUsedCols []*expression.Column, opt return p, nil } prunedColumns := make([]*expression.Column, 0) - used := expression.GetUsedList(p.SCtx(), parentUsedCols, p.schema) + used := expression.GetUsedList(p.SCtx().GetExprCtx(), parentUsedCols, p.schema) for i := len(used) - 1; i >= 0; i-- { if !used[i] && p.schema.Len() > 1 { prunedColumns = append(prunedColumns, p.schema.Columns[i]) @@ -451,7 +451,7 @@ func (p *LogicalMemTable) PruneColumns(parentUsedCols []*expression.Column, opt // PruneColumns implements LogicalPlan interface. func (p *LogicalTableDual) PruneColumns(parentUsedCols []*expression.Column, opt *logicalOptimizeOp) (LogicalPlan, error) { - used := expression.GetUsedList(p.SCtx(), parentUsedCols, p.Schema()) + used := expression.GetUsedList(p.SCtx().GetExprCtx(), parentUsedCols, p.Schema()) prunedColumns := make([]*expression.Column, 0) for i := len(used) - 1; i >= 0; i-- { if !used[i] { diff --git a/pkg/planner/core/rule_decorrelate.go b/pkg/planner/core/rule_decorrelate.go index 2deccf3befdca..d99e3669ccdff 100644 --- a/pkg/planner/core/rule_decorrelate.go +++ b/pkg/planner/core/rule_decorrelate.go @@ -46,7 +46,7 @@ func (la *LogicalAggregation) canPullUp() bool { } for _, f := range la.AggFuncs { for _, arg := range f.Args { - expr := expression.EvaluateExprWithNull(la.SCtx(), la.children[0].Schema(), arg) + expr := expression.EvaluateExprWithNull(la.SCtx().GetExprCtx(), la.children[0].Schema(), arg) if con, ok := expr.(*expression.Constant); !ok || !con.Value.IsNull() { return false } @@ -69,7 +69,7 @@ func (la *LogicalApply) deCorColFromEqExpr(expr expression.Expression) expressio return nil } // We should make sure that the equal condition's left side is the join's left join key, right is the right key. - return expression.NewFunctionInternal(la.SCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), ret, col) + return expression.NewFunctionInternal(la.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), ret, col) } } if corCol, lOk := sf.GetArgs()[0].(*expression.CorrelatedColumn); lOk { @@ -79,7 +79,7 @@ func (la *LogicalApply) deCorColFromEqExpr(expr expression.Expression) expressio return nil } // We should make sure that the equal condition's left side is the join's left join key, right is the right key. - return expression.NewFunctionInternal(la.SCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), ret, col) + return expression.NewFunctionInternal(la.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), ret, col) } } return nil @@ -312,7 +312,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo outerColsInSchema := make([]*expression.Column, 0, outerPlan.Schema().Len()) for i, col := range outerPlan.Schema().Columns { - first, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) + first, err := aggregation.NewAggFuncDesc(agg.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) if err != nil { return nil, planChanged, err } @@ -342,7 +342,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo aggArgs = append(aggArgs, expr) } } - desc, err := aggregation.NewAggFuncDesc(agg.SCtx(), agg.AggFuncs[i].Name, aggArgs, agg.AggFuncs[i].HasDistinct) + desc, err := aggregation.NewAggFuncDesc(agg.SCtx().GetExprCtx(), agg.AggFuncs[i].Name, aggArgs, agg.AggFuncs[i].HasDistinct) if err != nil { return nil, planChanged, err } @@ -390,7 +390,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo clonedCol := eqCond.GetArgs()[1].(*expression.Column) // If the join key is not in the aggregation's schema, add first row function. if agg.schema.ColumnIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 { - newFunc, err := aggregation.NewAggFuncDesc(apply.SCtx(), ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) + newFunc, err := aggregation.NewAggFuncDesc(apply.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) if err != nil { return nil, planChanged, err } @@ -419,7 +419,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan, opt *lo proj.Exprs = expression.Column2Exprs(apply.schema.Columns) for i, val := range defaultValueMap { pos := proj.schema.ColumnIndex(agg.schema.Columns[i]) - ifNullFunc := expression.NewFunctionInternal(agg.SCtx(), ast.Ifnull, types.NewFieldType(mysql.TypeLonglong), agg.schema.Columns[i], val) + ifNullFunc := expression.NewFunctionInternal(agg.SCtx().GetExprCtx(), ast.Ifnull, types.NewFieldType(mysql.TypeLonglong), agg.schema.Columns[i], val) proj.Exprs[pos] = ifNullFunc } proj.SetChildren(apply) diff --git a/pkg/planner/core/rule_derive_topn_from_window.go b/pkg/planner/core/rule_derive_topn_from_window.go index d962387be4d7a..2633fec00039b 100644 --- a/pkg/planner/core/rule_derive_topn_from_window.go +++ b/pkg/planner/core/rule_derive_topn_from_window.go @@ -91,7 +91,7 @@ func windowIsTopN(p *LogicalSelection) (bool, uint64) { // Check if filter on window function windowColumns := child.GetWindowResultColumns() - if len(windowColumns) != 1 || !(column.Equal(p.SCtx(), windowColumns[0])) { + if len(windowColumns) != 1 || !(column.Equal(p.SCtx().GetExprCtx(), windowColumns[0])) { return false, 0 } diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go index 36d7dbc82fd8d..55b6c5bbfe9fe 100644 --- a/pkg/planner/core/rule_eliminate_projection.go +++ b/pkg/planner/core/rule_eliminate_projection.go @@ -212,7 +212,7 @@ func (pe *projectionEliminator) eliminate(p LogicalPlan, replace map[string]*exp ctx := p.SCtx() for i := range proj.Exprs { proj.Exprs[i] = ReplaceColumnOfExpr(proj.Exprs[i], child, child.Schema()) - foldedExpr := expression.FoldConstant(ctx, proj.Exprs[i]) + foldedExpr := expression.FoldConstant(ctx.GetExprCtx(), proj.Exprs[i]) // the folded expr should have the same null flag with the original expr, especially for the projection under union, so forcing it here. foldedExpr.GetType().SetFlag((foldedExpr.GetType().GetFlag() & ^mysql.NotNullFlag) | (proj.Exprs[i].GetType().GetFlag() & mysql.NotNullFlag)) proj.Exprs[i] = foldedExpr diff --git a/pkg/planner/core/rule_generate_column_substitute.go b/pkg/planner/core/rule_generate_column_substitute.go index c2d92af20628e..1f30c3e39afe5 100644 --- a/pkg/planner/core/rule_generate_column_substitute.go +++ b/pkg/planner/core/rule_generate_column_substitute.go @@ -86,7 +86,7 @@ func collectGenerateColumn(lp LogicalPlan, exprToColumn ExprColumnMap) { func tryToSubstituteExpr(expr *expression.Expression, lp LogicalPlan, candidateExpr expression.Expression, tp types.EvalType, schema *expression.Schema, col *expression.Column, opt *logicalOptimizeOp) bool { changed := false - if (*expr).Equal(lp.SCtx(), candidateExpr) && candidateExpr.GetType().EvalType() == tp && + if (*expr).Equal(lp.SCtx().GetExprCtx(), candidateExpr) && candidateExpr.GetType().EvalType() == tp && schema.ColumnIndex(col) != -1 { *expr = col appendSubstituteColumnStep(lp, candidateExpr, col, opt) @@ -198,7 +198,7 @@ func (gc *gcSubstituter) substitute(ctx context.Context, lp LogicalPlan, exprToC for i := 0; i < len(aggFunc.Args); i++ { tp = aggFunc.Args[i].GetType().EvalType() for candidateExpr, column := range exprToColumn { - if aggFunc.Args[i].Equal(lp.SCtx(), candidateExpr) && candidateExpr.GetType().EvalType() == tp && + if aggFunc.Args[i].Equal(lp.SCtx().GetExprCtx(), candidateExpr) && candidateExpr.GetType().EvalType() == tp && x.Schema().ColumnIndex(column) != -1 { aggFunc.Args[i] = column appendSubstituteColumnStep(lp, candidateExpr, column, opt) @@ -209,7 +209,7 @@ func (gc *gcSubstituter) substitute(ctx context.Context, lp LogicalPlan, exprToC for i := 0; i < len(x.GroupByItems); i++ { tp = x.GroupByItems[i].GetType().EvalType() for candidateExpr, column := range exprToColumn { - if x.GroupByItems[i].Equal(lp.SCtx(), candidateExpr) && candidateExpr.GetType().EvalType() == tp && + if x.GroupByItems[i].Equal(lp.SCtx().GetExprCtx(), candidateExpr) && candidateExpr.GetType().EvalType() == tp && x.Schema().ColumnIndex(column) != -1 { x.GroupByItems[i] = column appendSubstituteColumnStep(lp, candidateExpr, column, opt) diff --git a/pkg/planner/core/rule_inject_extra_projection.go b/pkg/planner/core/rule_inject_extra_projection.go index 5d1983df08be6..0119fa3684f37 100644 --- a/pkg/planner/core/rule_inject_extra_projection.go +++ b/pkg/planner/core/rule_inject_extra_projection.go @@ -90,7 +90,7 @@ func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll { srcCol.Index = i srcType := srcCol.RetType if !srcType.Equal(dstType) || !(mysql.HasNotNullFlag(dstType.GetFlag()) == mysql.HasNotNullFlag(srcType.GetFlag())) { - exprs[i] = expression.BuildCastFunction4Union(un.SCtx(), srcCol, dstType) + exprs[i] = expression.BuildCastFunction4Union(un.SCtx().GetExprCtx(), srcCol, dstType) needChange = true } else { exprs[i] = srcCol @@ -115,7 +115,7 @@ func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll { func InjectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) PhysicalPlan { hasScalarFunc := false - internal.WrapCastForAggFuncs(aggPlan.SCtx(), aggFuncs) + internal.WrapCastForAggFuncs(aggPlan.SCtx().GetExprCtx(), aggFuncs) for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ { for _, arg := range aggFuncs[i].Args { _, isScalarFunc := arg.(*expression.ScalarFunction) diff --git a/pkg/planner/core/rule_join_reorder.go b/pkg/planner/core/rule_join_reorder.go index 0fd1d54940982..a7182d38a6b42 100644 --- a/pkg/planner/core/rule_join_reorder.go +++ b/pkg/planner/core/rule_join_reorder.go @@ -511,7 +511,7 @@ func (s *baseSingleGroupJoinOrderSolver) checkConnection(leftPlan, rightPlan Log rightNode, leftNode = leftPlan, rightPlan usedEdges = append(usedEdges, edge) } else { - newSf := expression.NewFunctionInternal(s.ctx, ast.EQ, edge.GetType(), rCol, lCol).(*expression.ScalarFunction) + newSf := expression.NewFunctionInternal(s.ctx.GetExprCtx(), ast.EQ, edge.GetType(), rCol, lCol).(*expression.ScalarFunction) usedEdges = append(usedEdges, newSf) } } diff --git a/pkg/planner/core/rule_join_reorder_dp.go b/pkg/planner/core/rule_join_reorder_dp.go index e49f79e02d968..916e4c4fbd5f1 100644 --- a/pkg/planner/core/rule_join_reorder_dp.go +++ b/pkg/planner/core/rule_join_reorder_dp.go @@ -248,7 +248,7 @@ func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, e if leftPlan.Schema().Contains(lCol) { eqConds = append(eqConds, edge.edge) } else { - newSf := expression.NewFunctionInternal(s.ctx, ast.EQ, edge.edge.GetType(), rCol, lCol).(*expression.ScalarFunction) + newSf := expression.NewFunctionInternal(s.ctx.GetExprCtx(), ast.EQ, edge.edge.GetType(), rCol, lCol).(*expression.ScalarFunction) eqConds = append(eqConds, newSf) } } diff --git a/pkg/planner/core/rule_max_min_eliminate.go b/pkg/planner/core/rule_max_min_eliminate.go index 519aede720503..e4ec913d49995 100644 --- a/pkg/planner/core/rule_max_min_eliminate.go +++ b/pkg/planner/core/rule_max_min_eliminate.go @@ -182,8 +182,8 @@ func (*maxMinEliminator) eliminateSingleMaxMin(agg *LogicalAggregation, opt *log // If it can be NULL, we need to filter NULL out first. if !mysql.HasNotNullFlag(f.Args[0].GetType().GetFlag()) { sel = LogicalSelection{}.Init(ctx, agg.QueryBlockOffset()) - isNullFunc := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), f.Args[0]) - notNullFunc := expression.NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNullFunc) + isNullFunc := expression.NewFunctionInternal(ctx.GetExprCtx(), ast.IsNull, types.NewFieldType(mysql.TypeTiny), f.Args[0]) + notNullFunc := expression.NewFunctionInternal(ctx.GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNullFunc) sel.Conditions = []expression.Expression{notNullFunc} sel.SetChildren(agg.Children()[0]) child = sel diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index ef21fb84a09b0..f82b4964dfe66 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -125,7 +125,7 @@ func generateHashPartitionExpr(ctx PlanContext, pi *model.PartitionInfo, columns // we have to increase the `PlanID` here. But it is safe to remove this line without introducing any bug. // TODO: remove this line after fixing the test cases. ctx.GetSessionVars().PlanID.Add(1) - expr, err := expression.ParseSimpleExpr(ctx, pi.Expr, expression.WithInputSchemaAndNames(schema, names, nil)) + expr, err := expression.ParseSimpleExpr(ctx.GetExprCtx(), pi.Expr, expression.WithInputSchemaAndNames(schema, names, nil)) if err != nil { return nil, err } @@ -168,12 +168,12 @@ func (s *partitionProcessor) getUsedHashPartitions(ctx PlanContext, if col, ok := hashExpr.(*expression.Column); ok && col.RetType.EvalType() == types.ETInt { numPartitions := len(pi.Definitions) - posHigh, highIsNull, err := hashExpr.EvalInt(ctx, chunk.MutRowFromDatums(r.HighVal).ToRow()) + posHigh, highIsNull, err := hashExpr.EvalInt(ctx.GetExprCtx(), chunk.MutRowFromDatums(r.HighVal).ToRow()) if err != nil { return nil, nil, err } - posLow, lowIsNull, err := hashExpr.EvalInt(ctx, chunk.MutRowFromDatums(r.LowVal).ToRow()) + posLow, lowIsNull, err := hashExpr.EvalInt(ctx.GetExprCtx(), chunk.MutRowFromDatums(r.LowVal).ToRow()) if err != nil { return nil, nil, err } @@ -244,7 +244,7 @@ func (s *partitionProcessor) getUsedHashPartitions(ctx PlanContext, highLowVals := make([]types.Datum, 0, len(r.HighVal)+len(r.LowVal)) highLowVals = append(highLowVals, r.HighVal...) highLowVals = append(highLowVals, r.LowVal...) - pos, isNull, err := hashExpr.EvalInt(ctx, chunk.MutRowFromDatums(highLowVals).ToRow()) + pos, isNull, err := hashExpr.EvalInt(ctx.GetExprCtx(), chunk.MutRowFromDatums(highLowVals).ToRow()) if err != nil { // If we failed to get the point position, we can just skip and ignore it. continue @@ -280,12 +280,12 @@ func (s *partitionProcessor) getUsedKeyPartitions(ctx PlanContext, if !r.IsPointNullable(tc) { if len(partCols) == 1 && partCols[0].RetType.EvalType() == types.ETInt { col := partCols[0] - posHigh, highIsNull, err := col.EvalInt(ctx, chunk.MutRowFromDatums(r.HighVal).ToRow()) + posHigh, highIsNull, err := col.EvalInt(ctx.GetExprCtx(), chunk.MutRowFromDatums(r.HighVal).ToRow()) if err != nil { return nil, nil, err } - posLow, lowIsNull, err := col.EvalInt(ctx, chunk.MutRowFromDatums(r.LowVal).ToRow()) + posLow, lowIsNull, err := col.EvalInt(ctx.GetExprCtx(), chunk.MutRowFromDatums(r.LowVal).ToRow()) if err != nil { return nil, nil, err } @@ -764,7 +764,7 @@ func (l *listPartitionPruner) findUsedListPartitions(conds []expression.Expressi if len(r.HighVal) != len(exprCols) { return l.fullRange, nil } - value, isNull, err := pruneExpr.EvalInt(l.ctx, chunk.MutRowFromDatums(r.HighVal).ToRow()) + value, isNull, err := pruneExpr.EvalInt(l.ctx.GetExprCtx(), chunk.MutRowFromDatums(r.HighVal).ToRow()) if err != nil { return nil, err } @@ -826,7 +826,7 @@ func (s *partitionProcessor) prune(ds *DataSource, opt *logicalOptimizeOp) (Logi // like 'not (a != 1)' would not be handled so we need to convert it to 'a = 1', which can be handled when building range. // TODO: there may be a better way to push down Not once for all. for i, cond := range ds.allConds { - ds.allConds[i] = expression.PushDownNot(ds.SCtx(), cond) + ds.allConds[i] = expression.PushDownNot(ds.SCtx().GetExprCtx(), cond) } // Try to locate partition directly for hash partition. switch pi.Type { @@ -1059,7 +1059,7 @@ func makePartitionByFnCol(sctx PlanContext, columns []*expression.Column, names // we have to increase the `PlanID` here. But it is safe to remove this line without introducing any bug. // TODO: remove this line after fixing the test cases. sctx.GetSessionVars().PlanID.Add(1) - partExpr, err := expression.ParseSimpleExpr(sctx, partitionExpr, expression.WithInputSchemaAndNames(schema, names, nil)) + partExpr, err := expression.ParseSimpleExpr(sctx.GetExprCtx(), partitionExpr, expression.WithInputSchemaAndNames(schema, names, nil)) if err != nil { return nil, nil, monotonous, err } @@ -1377,7 +1377,7 @@ func partitionRangeColumnForInExpr(sctx PlanContext, args []expression.Expressio } // convert all elements to EQ-exprs and prune them one by one - sf, err := expression.NewFunction(sctx, ast.EQ, types.NewFieldType(types.KindInt64), []expression.Expression{col, args[i]}...) + sf, err := expression.NewFunction(sctx.GetExprCtx(), ast.EQ, types.NewFieldType(types.KindInt64), []expression.Expression{col, args[i]}...) if err != nil { return pruner.fullRange() } @@ -1415,10 +1415,10 @@ func partitionRangeForInExpr(sctx PlanContext, args []expression.Expression, if pruner.partFn != nil { // replace fn(col) to fn(const) partFnConst := replaceColumnWithConst(pruner.partFn, constExpr) - val, _, err = partFnConst.EvalInt(sctx, chunk.Row{}) + val, _, err = partFnConst.EvalInt(sctx.GetExprCtx(), chunk.Row{}) unsigned = mysql.HasUnsignedFlag(partFnConst.GetType().GetFlag()) } else { - val, _, err = constExpr.EvalInt(sctx, chunk.Row{}) + val, _, err = constExpr.EvalInt(sctx.GetExprCtx(), chunk.Row{}) unsigned = mysql.HasUnsignedFlag(constExpr.GetType().GetFlag()) } if err != nil { @@ -1534,7 +1534,7 @@ func (p *rangePruner) extractDataForPrune(sctx PlanContext, expr expression.Expr if !expression.ConstExprConsiderPlanCache(constExpr, sctx.GetSessionVars().StmtCtx.UseCache) { return ret, false } - c, isNull, err := constExpr.EvalInt(sctx, chunk.Row{}) + c, isNull, err := constExpr.EvalInt(sctx.GetExprCtx(), chunk.Row{}) if err == nil && !isNull { ret.c = c ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType().GetFlag()) @@ -1949,14 +1949,14 @@ func (p *rangeColumnsPruner) pruneUseBinarySearch(sctx PlanContext, op string, d if p.lessThan[ith][i] == nil { // MAXVALUE return true } - expr, err := expression.NewFunctionBase(sctx, op, types.NewFieldType(mysql.TypeLonglong), *p.lessThan[ith][i], v) + expr, err := expression.NewFunctionBase(sctx.GetExprCtx(), op, types.NewFieldType(mysql.TypeLonglong), *p.lessThan[ith][i], v) if err != nil { savedError = err return true } expr.SetCharsetAndCollation(charSet, collation) var val int64 - val, isNull, err = expr.EvalInt(sctx, chunk.Row{}) + val, isNull, err = expr.EvalInt(sctx.GetExprCtx(), chunk.Row{}) if err != nil { savedError = err return true diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index 105e2754aa035..9b3588b0d98bb 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -52,7 +52,7 @@ func addSelection(p LogicalPlan, child LogicalPlan, conditions []expression.Expr p.Children()[chIdx] = child return } - conditions = expression.PropagateConstant(p.SCtx(), conditions) + conditions = expression.PropagateConstant(p.SCtx().GetExprCtx(), conditions) // Return table dual when filter is constant false or null. dual := Conds2TableDual(child, conditions) if dual != nil { @@ -107,8 +107,9 @@ func (p *LogicalSelection) PredicatePushDown(predicates []expression.Expression, originConditions = canBePushDown retConditions, child = p.children[0].PredicatePushDown(append(canBePushDown, predicates...), opt) retConditions = append(retConditions, canNotBePushDown...) + exprCtx := p.SCtx().GetExprCtx() if len(retConditions) > 0 { - p.Conditions = expression.PropagateConstant(p.SCtx(), retConditions) + p.Conditions = expression.PropagateConstant(exprCtx, retConditions) // Return table dual when filter is constant false or null. dual := Conds2TableDual(p, p.Conditions) if dual != nil { @@ -132,13 +133,13 @@ func (p *LogicalUnionScan) PredicatePushDown(predicates []expression.Expression, // PredicatePushDown implements LogicalPlan PredicatePushDown interface. func (ds *DataSource) PredicatePushDown(predicates []expression.Expression, opt *logicalOptimizeOp) ([]expression.Expression, LogicalPlan) { - predicates = expression.PropagateConstant(ds.SCtx(), predicates) + predicates = expression.PropagateConstant(ds.SCtx().GetExprCtx(), predicates) predicates = DeleteTrueExprs(ds, predicates) // Add tidb_shard() prefix to the condtion for shard index in some scenarios // TODO: remove it to the place building logical plan predicates = ds.AddPrefix4ShardIndexes(ds.SCtx(), predicates) ds.allConds = predicates - ds.pushedDownConds, predicates = expression.PushDownExprs(ds.SCtx(), predicates, ds.SCtx().GetClient(), kv.UnSpecified) + ds.pushedDownConds, predicates = expression.PushDownExprs(ds.SCtx().GetExprCtx(), predicates, ds.SCtx().GetClient(), kv.UnSpecified) appendDataSourcePredicatePushDownTraceStep(ds, opt) return predicates, ds } @@ -162,7 +163,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt return ret, dual } // Handle where conditions - predicates = expression.ExtractFiltersFromDNFs(p.SCtx(), predicates) + predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) // Only derive left where condition, because right where condition cannot be pushed down equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true, false) leftCond = leftPushCond @@ -181,7 +182,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt return ret, dual } // Handle where conditions - predicates = expression.ExtractFiltersFromDNFs(p.SCtx(), predicates) + predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) // Only derive right where condition, because left where condition cannot be pushed down equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, false, true) rightCond = rightPushCond @@ -199,8 +200,8 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt tempCond = append(tempCond, expression.ScalarFuncs2Exprs(p.EqualConditions)...) tempCond = append(tempCond, p.OtherConditions...) tempCond = append(tempCond, predicates...) - tempCond = expression.ExtractFiltersFromDNFs(p.SCtx(), tempCond) - tempCond = expression.PropagateConstant(p.SCtx(), tempCond) + tempCond = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), tempCond) + tempCond = expression.PropagateConstant(p.SCtx().GetExprCtx(), tempCond) // Return table dual when filter is constant false or null. dual := Conds2TableDual(p, tempCond) if dual != nil { @@ -215,7 +216,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt leftCond = leftPushCond rightCond = rightPushCond case AntiSemiJoin: - predicates = expression.PropagateConstant(p.SCtx(), predicates) + predicates = expression.PropagateConstant(p.SCtx().GetExprCtx(), predicates) // Return table dual when filter is constant false or null. dual := Conds2TableDual(p, predicates) if dual != nil { @@ -304,7 +305,7 @@ func (p *LogicalJoin) updateEQCond() { if rProj != nil { rKey = rProj.appendExpr(rKey) } - eqCond := expression.NewFunctionInternal(p.SCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) + eqCond := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) if isNA { p.NAEQConditions = append(p.NAEQConditions, eqCond.(*expression.ScalarFunction)) } else { @@ -349,7 +350,7 @@ func (p *LogicalProjection) appendExpr(expr expression.Expression) *expression.C if col, ok := expr.(*expression.Column); ok { return col } - expr = expression.ColumnSubstitute(p.SCtx(), expr, p.schema, p.Exprs) + expr = expression.ColumnSubstitute(p.SCtx().GetExprCtx(), expr, p.schema, p.Exprs) p.Exprs = append(p.Exprs, expr) col := &expression.Column{ @@ -429,7 +430,8 @@ func simplifyOuterJoin(p *LogicalJoin, predicates []expression.Expression) { // If it is a conjunction containing a null-rejected condition as a conjunct. // If it is a disjunction of null-rejected conditions. func isNullRejected(ctx PlanContext, schema *expression.Schema, expr expression.Expression) bool { - expr = expression.PushDownNot(ctx, expr) + exprCtx := ctx.GetExprCtx() + expr = expression.PushDownNot(exprCtx, expr) if expression.ContainOuterNot(expr) { return false } @@ -443,7 +445,7 @@ func isNullRejected(ctx PlanContext, schema *expression.Schema, expr expression. return true } - result := expression.EvaluateExprWithNull(ctx, schema, cond) + result := expression.EvaluateExprWithNull(exprCtx, schema, cond) x, ok := result.(*expression.Constant) if !ok { continue @@ -525,9 +527,9 @@ func (p *LogicalProjection) PredicatePushDown(predicates []expression.Expression return predicates, p } } - ctx := p.SCtx() + exprCtx := p.SCtx().GetExprCtx() for _, cond := range predicates { - substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(ctx, cond, p.Schema(), p.Exprs, true) + substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(exprCtx, cond, p.Schema(), p.Exprs, true) if substituted && !hasFailed && !expression.HasGetSetVarFunc(newFilter) { canBePushed = append(canBePushed, newFilter) } else { @@ -570,7 +572,7 @@ func (la *LogicalAggregation) pushDownPredicatesForAggregation(cond expression.E } } if ok { - newFunc := expression.ColumnSubstitute(la.SCtx(), cond, la.Schema(), exprsOriginal) + newFunc := expression.ColumnSubstitute(la.SCtx().GetExprCtx(), cond, la.Schema(), exprsOriginal) condsToPush = append(condsToPush, newFunc) } else { ret = append(ret, cond) @@ -593,13 +595,14 @@ func (la *LogicalAggregation) pushDownCNFPredicatesForAggregation(cond expressio if len(subCNFItem) == 1 { return la.pushDownPredicatesForAggregation(subCNFItem[0], groupByColumns, exprsOriginal) } + exprCtx := la.SCtx().GetExprCtx() for _, item := range subCNFItem { condsToPushForItem, retForItem := la.pushDownDNFPredicatesForAggregation(item, groupByColumns, exprsOriginal) if len(condsToPushForItem) > 0 { - condsToPush = append(condsToPush, expression.ComposeDNFCondition(la.SCtx(), condsToPushForItem...)) + condsToPush = append(condsToPush, expression.ComposeDNFCondition(exprCtx, condsToPushForItem...)) } if len(retForItem) > 0 { - ret = append(ret, expression.ComposeDNFCondition(la.SCtx(), retForItem...)) + ret = append(ret, expression.ComposeDNFCondition(exprCtx, retForItem...)) } } return condsToPush, ret @@ -618,21 +621,22 @@ func (la *LogicalAggregation) pushDownDNFPredicatesForAggregation(cond expressio if len(subDNFItem) == 1 { return la.pushDownPredicatesForAggregation(subDNFItem[0], groupByColumns, exprsOriginal) } + exprCtx := la.SCtx().GetExprCtx() for _, item := range subDNFItem { condsToPushForItem, retForItem := la.pushDownCNFPredicatesForAggregation(item, groupByColumns, exprsOriginal) if len(condsToPushForItem) <= 0 { return nil, []expression.Expression{cond} } - condsToPush = append(condsToPush, expression.ComposeCNFCondition(la.SCtx(), condsToPushForItem...)) + condsToPush = append(condsToPush, expression.ComposeCNFCondition(exprCtx, condsToPushForItem...)) if len(retForItem) > 0 { - ret = append(ret, expression.ComposeCNFCondition(la.SCtx(), retForItem...)) + ret = append(ret, expression.ComposeCNFCondition(exprCtx, retForItem...)) } } if len(ret) == 0 { // All the condition can be pushed down. return []expression.Expression{cond}, nil } - dnfPushDownCond := expression.ComposeDNFCondition(la.SCtx(), condsToPush...) + dnfPushDownCond := expression.ComposeDNFCondition(exprCtx, condsToPush...) // Some condition can't be pushed down, we need to keep all the condition. return []expression.Expression{dnfPushDownCond}, []expression.Expression{cond} } @@ -681,9 +685,10 @@ func DeriveOtherConditions( leftCond []expression.Expression, rightCond []expression.Expression) { isOuterSemi := (p.JoinType == LeftOuterSemiJoin) || (p.JoinType == AntiLeftOuterSemiJoin) ctx := p.SCtx() + exprCtx := ctx.GetExprCtx() for _, expr := range p.OtherConditions { if deriveLeft { - leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, leftSchema) + leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(exprCtx, expr, leftSchema) if leftRelaxedCond != nil { leftCond = append(leftCond, leftRelaxedCond) } @@ -693,7 +698,7 @@ func DeriveOtherConditions( } } if deriveRight { - rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, rightSchema) + rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(exprCtx, expr, rightSchema) if rightRelaxedCond != nil { rightCond = append(rightCond, rightRelaxedCond) } @@ -734,7 +739,7 @@ func deriveNotNullExpr(ctx PlanContext, expr expression.Expression, schema *expr childCol = schema.RetrieveColumn(arg1) } if isNullRejected(ctx, schema, expr) && !mysql.HasNotNullFlag(childCol.RetType.GetFlag()) { - return expression.BuildNotNullExpr(ctx, childCol) + return expression.BuildNotNullExpr(ctx.GetExprCtx(), childCol) } return nil } @@ -749,7 +754,7 @@ func Conds2TableDual(p LogicalPlan, conds []expression.Expression) LogicalPlan { return nil } sc := p.SCtx().GetSessionVars().StmtCtx - if expression.MaybeOverOptimized4PlanCache(p.SCtx(), []expression.Expression{con}) { + if expression.MaybeOverOptimized4PlanCache(p.SCtx().GetExprCtx(), []expression.Expression{con}) { return nil } if isTrue, err := con.Value.ToBool(sc.TypeCtxOrDefault()); (err == nil && isTrue == 0) || con.Value.IsNull() { @@ -769,7 +774,7 @@ func DeleteTrueExprs(p LogicalPlan, conds []expression.Expression) []expression. newConds = append(newConds, cond) continue } - if expression.MaybeOverOptimized4PlanCache(p.SCtx(), []expression.Expression{con}) { + if expression.MaybeOverOptimized4PlanCache(p.SCtx().GetExprCtx(), []expression.Expression{con}) { newConds = append(newConds, cond) continue } @@ -802,7 +807,7 @@ func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []e p.RightConditions = nil p.OtherConditions = nil nullSensitive := p.JoinType == AntiLeftOuterSemiJoin || p.JoinType == LeftOuterSemiJoin - joinConds, predicates = expression.PropConstOverOuterJoin(p.SCtx(), joinConds, predicates, outerTable.Schema(), innerTable.Schema(), nullSensitive) + joinConds, predicates = expression.PropConstOverOuterJoin(p.SCtx().GetExprCtx(), joinConds, predicates, outerTable.Schema(), innerTable.Schema(), nullSensitive) p.AttachOnConds(joinConds) return predicates } @@ -1009,6 +1014,7 @@ func (adder *exprPrefixAdder) addExprPrefix4DNFCond(condition *expression.Scalar dnfItems := expression.FlattenDNFConditions(condition) newAccessItems := make([]expression.Expression, 0, len(dnfItems)) + exprCtx := adder.sctx.GetExprCtx() for _, item := range dnfItems { if sf, ok := item.(*expression.ScalarFunction); ok { var accesses []expression.Expression @@ -1018,14 +1024,14 @@ func (adder *exprPrefixAdder) addExprPrefix4DNFCond(condition *expression.Scalar if err != nil { return []expression.Expression{condition}, err } - newAccessItems = append(newAccessItems, expression.ComposeCNFCondition(adder.sctx, accesses...)) + newAccessItems = append(newAccessItems, expression.ComposeCNFCondition(exprCtx, accesses...)) } else if sf.FuncName.L == ast.EQ || sf.FuncName.L == ast.In { // only add prefix expression for EQ or IN function accesses, err = adder.addExprPrefix4CNFCond([]expression.Expression{sf}) if err != nil { return []expression.Expression{condition}, err } - newAccessItems = append(newAccessItems, expression.ComposeCNFCondition(adder.sctx, accesses...)) + newAccessItems = append(newAccessItems, expression.ComposeCNFCondition(exprCtx, accesses...)) } else { newAccessItems = append(newAccessItems, item) } @@ -1034,7 +1040,7 @@ func (adder *exprPrefixAdder) addExprPrefix4DNFCond(condition *expression.Scalar } } - return []expression.Expression{expression.ComposeDNFCondition(adder.sctx, newAccessItems...)}, nil + return []expression.Expression{expression.ComposeDNFCondition(exprCtx, newAccessItems...)}, nil } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. @@ -1068,7 +1074,7 @@ func (p *LogicalCTE) PredicatePushDown(predicates []expression.Expression, _ *lo newPred = append(newPred, pushedPredicates[i].Clone()) ResolveExprAndReplace(newPred[i], p.cte.ColumnMap) } - p.cte.pushDownPredicates = append(p.cte.pushDownPredicates, expression.ComposeCNFCondition(p.SCtx(), newPred...)) + p.cte.pushDownPredicates = append(p.cte.pushDownPredicates, expression.ComposeCNFCondition(p.SCtx().GetExprCtx(), newPred...)) return predicates, p.self } diff --git a/pkg/planner/core/rule_predicate_simplification.go b/pkg/planner/core/rule_predicate_simplification.go index 1319bcab0a848..387898cb57a78 100644 --- a/pkg/planner/core/rule_predicate_simplification.go +++ b/pkg/planner/core/rule_predicate_simplification.go @@ -96,7 +96,7 @@ func updateInPredicate(ctx PlanContext, inPredicate expression.Expression, notEQ var lastValue *expression.Constant for _, element := range v.GetArgs() { value, valueOK := element.(*expression.Constant) - redundantValue := valueOK && value.Equal(ctx, notEQValue) + redundantValue := valueOK && value.Equal(ctx.GetExprCtx(), notEQValue) if !redundantValue { newValues = append(newValues, element) } @@ -112,7 +112,7 @@ func updateInPredicate(ctx PlanContext, inPredicate expression.Expression, notEQ newValues = append(newValues, lastValue) specialCase = true } - newPred := expression.NewFunctionInternal(ctx, v.FuncName.L, v.RetType, newValues...) + newPred := expression.NewFunctionInternal(ctx.GetExprCtx(), v.FuncName.L, v.RetType, newValues...) return newPred, specialCase } diff --git a/pkg/planner/core/rule_semi_join_rewrite.go b/pkg/planner/core/rule_semi_join_rewrite.go index ca856aa967ca0..1209b3a3ed60a 100644 --- a/pkg/planner/core/rule_semi_join_rewrite.go +++ b/pkg/planner/core/rule_semi_join_rewrite.go @@ -103,7 +103,7 @@ func (smj *semiJoinRewriter) recursivePlan(p LogicalPlan) (LogicalPlan, error) { aggOutputCols := make([]*expression.Column, 0, len(join.EqualConditions)) for i := range join.EqualConditions { innerCol := join.EqualConditions[i].GetArgs()[1].(*expression.Column) - firstRow, err := aggregation.NewAggFuncDesc(join.SCtx(), ast.AggFuncFirstRow, []expression.Expression{innerCol}, false) + firstRow, err := aggregation.NewAggFuncDesc(join.SCtx().GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{innerCol}, false) if err != nil { return nil, err } diff --git a/pkg/planner/core/rule_topn_push_down.go b/pkg/planner/core/rule_topn_push_down.go index 04e3f2a3e1a98..e8f9c672c982f 100644 --- a/pkg/planner/core/rule_topn_push_down.go +++ b/pkg/planner/core/rule_topn_push_down.go @@ -133,9 +133,9 @@ func (p *LogicalProjection) pushDownTopN(topN *LogicalTopN, opt *logicalOptimize } } if topN != nil { - ctx := p.SCtx() + exprCtx := p.SCtx().GetExprCtx() for _, by := range topN.ByItems { - by.Expr = expression.FoldConstant(ctx, expression.ColumnSubstitute(ctx, by.Expr, p.schema, p.Exprs)) + by.Expr = expression.FoldConstant(exprCtx, expression.ColumnSubstitute(exprCtx, by.Expr, p.schema, p.Exprs)) } // remove meaningless constant sort items. diff --git a/pkg/planner/core/runtime_filter.go b/pkg/planner/core/runtime_filter.go index cb468260fa0a1..7ff21dd4eb6f3 100644 --- a/pkg/planner/core/runtime_filter.go +++ b/pkg/planner/core/runtime_filter.go @@ -215,7 +215,7 @@ func RuntimeFilterListToPB(ctx PlanContext, runtimeFilterList []*RuntimeFilter, // ToPB convert runtime filter to PB func (rf *RuntimeFilter) ToPB(ctx PlanContext, client kv.Client) (*tipb.RuntimeFilter, error) { - pc := expression.NewPBConverter(client, ctx) + pc := expression.NewPBConverter(client, ctx.GetExprCtx()) srcExprListPB := make([]*tipb.Expr, 0, len(rf.srcExprList)) for _, srcExpr := range rf.srcExprList { srcExprPB := pc.ExprToPB(srcExpr) diff --git a/pkg/planner/core/stats.go b/pkg/planner/core/stats.go index 9a887f5226ab3..be2c3965ed806 100644 --- a/pkg/planner/core/stats.go +++ b/pkg/planner/core/stats.go @@ -470,9 +470,10 @@ func (ds *DataSource) DeriveStats(_ []*property.StatsInfo, _ *expression.Schema, // two preprocess here. // 1: PushDownNot here can convert query 'not (a != 1)' to 'a = 1'. // 2: EliminateNoPrecisionCast here can convert query 'cast(c as bigint) = 1' to 'c = 1' to leverage access range. + exprCtx := ds.SCtx().GetExprCtx() for i, expr := range ds.pushedDownConds { - ds.pushedDownConds[i] = expression.PushDownNot(ds.SCtx(), expr) - ds.pushedDownConds[i] = expression.EliminateNoPrecisionLossCast(ds.SCtx(), ds.pushedDownConds[i]) + ds.pushedDownConds[i] = expression.PushDownNot(exprCtx, expr) + ds.pushedDownConds[i] = expression.EliminateNoPrecisionLossCast(exprCtx, ds.pushedDownConds[i]) } for _, path := range ds.possibleAccessPaths { if path.IsTablePath() { @@ -524,10 +525,11 @@ func getMinSelectivityFromPaths(paths []*util.AccessPath, totalRowCount float64) func (ts *LogicalTableScan) DeriveStats(_ []*property.StatsInfo, _ *expression.Schema, _ []*expression.Schema, _ [][]*expression.Column) (_ *property.StatsInfo, err error) { ts.Source.initStats(nil) // PushDownNot here can convert query 'not (a != 1)' to 'a = 1'. + exprCtx := ts.SCtx().GetExprCtx() for i, expr := range ts.AccessConds { // TODO The expressions may be shared by TableScan and several IndexScans, there would be redundant // `PushDownNot` function call in multiple `DeriveStats` then. - ts.AccessConds[i] = expression.PushDownNot(ts.SCtx(), expr) + ts.AccessConds[i] = expression.PushDownNot(exprCtx, expr) } ts.SetStats(ts.Source.deriveStatsByFilter(ts.AccessConds, nil)) // ts.Handle could be nil if PK is Handle, and PK column has been pruned. @@ -553,8 +555,9 @@ func (ts *LogicalTableScan) DeriveStats(_ []*property.StatsInfo, _ *expression.S // DeriveStats implements LogicalPlan DeriveStats interface. func (is *LogicalIndexScan) DeriveStats(_ []*property.StatsInfo, selfSchema *expression.Schema, _ []*expression.Schema, _ [][]*expression.Column) (*property.StatsInfo, error) { is.Source.initStats(nil) + exprCtx := is.SCtx().GetExprCtx() for i, expr := range is.AccessConds { - is.AccessConds[i] = expression.PushDownNot(is.SCtx(), expr) + is.AccessConds[i] = expression.PushDownNot(exprCtx, expr) } is.SetStats(is.Source.deriveStatsByFilter(is.AccessConds, nil)) if len(is.AccessConds) == 0 { @@ -1009,7 +1012,7 @@ func (p *LogicalCTE) DeriveStats(_ []*property.StatsInfo, selfSchema *expression if p.cte.seedPartPhysicalPlan == nil { // Build push-downed predicates. if len(p.cte.pushDownPredicates) > 0 { - newCond := expression.ComposeDNFCondition(p.SCtx(), p.cte.pushDownPredicates...) + newCond := expression.ComposeDNFCondition(p.SCtx().GetExprCtx(), p.cte.pushDownPredicates...) newSel := LogicalSelection{Conditions: []expression.Expression{newCond}}.Init(p.SCtx(), p.cte.seedPartLogicalPlan.QueryBlockOffset()) newSel.SetChildren(p.cte.seedPartLogicalPlan) p.cte.seedPartLogicalPlan = newSel diff --git a/pkg/planner/core/task.go b/pkg/planner/core/task.go index 82ec231d4ad44..62c057e348006 100644 --- a/pkg/planner/core/task.go +++ b/pkg/planner/core/task.go @@ -468,13 +468,13 @@ func (p *PhysicalHashJoin) convertPartitionKeysIfNeed(lTask, rTask *mppTask) (*m if lMask[i] { cType := cTypes[i].Clone() cType.SetFlag(lKey.Col.RetType.GetFlag()) - lCast := expression.BuildCastFunction(p.SCtx(), lKey.Col, cType) + lCast := expression.BuildCastFunction(p.SCtx().GetExprCtx(), lKey.Col, cType) lKey = &property.MPPPartitionColumn{Col: appendExpr(lProj, lCast), CollateID: lKey.CollateID} } if rMask[i] { cType := cTypes[i].Clone() cType.SetFlag(rKey.Col.RetType.GetFlag()) - rCast := expression.BuildCastFunction(p.SCtx(), rKey.Col, cType) + rCast := expression.BuildCastFunction(p.SCtx().GetExprCtx(), rKey.Col, cType) rKey = &property.MPPPartitionColumn{Col: appendExpr(rProj, rCast), CollateID: rKey.CollateID} } lPartKeys = append(lPartKeys, lKey) @@ -1112,7 +1112,7 @@ func (p *PhysicalTopN) canExpressionConvertedToPB(storeTp kv.StoreType) bool { for _, item := range p.ByItems { exprs = append(exprs, item.Expr) } - return expression.CanExprsPushDown(p.SCtx(), exprs, p.SCtx().GetClient(), storeTp) + return expression.CanExprsPushDown(p.SCtx().GetExprCtx(), exprs, p.SCtx().GetClient(), storeTp) } // containVirtualColumn checks whether TopN.ByItems contains virtual generated columns. @@ -1213,12 +1213,12 @@ func (p *PhysicalExpand) attach2Task(tasks ...task) task { func (p *PhysicalProjection) attach2Task(tasks ...task) task { t := tasks[0].copy() if cop, ok := t.(*copTask); ok { - if (len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0) && expression.CanExprsPushDown(p.SCtx(), p.Exprs, p.SCtx().GetClient(), cop.getStoreType()) { + if (len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0) && expression.CanExprsPushDown(p.SCtx().GetExprCtx(), p.Exprs, p.SCtx().GetClient(), cop.getStoreType()) { copTask := attachPlan2Task(p, cop) return copTask } } else if mpp, ok := t.(*mppTask); ok { - if expression.CanExprsPushDown(p.SCtx(), p.Exprs, p.SCtx().GetClient(), kv.TiFlash) { + if expression.CanExprsPushDown(p.SCtx().GetExprCtx(), p.Exprs, p.SCtx().GetClient(), kv.TiFlash) { p.SetChildren(mpp.p) mpp.p = p return mpp @@ -1275,7 +1275,7 @@ func (p *PhysicalUnionAll) attach2Task(tasks ...task) task { func (sel *PhysicalSelection) attach2Task(tasks ...task) task { if mppTask, _ := tasks[0].(*mppTask); mppTask != nil { // always push to mpp task. - if expression.CanExprsPushDown(sel.SCtx(), sel.Conditions, sel.SCtx().GetClient(), kv.TiFlash) { + if expression.CanExprsPushDown(sel.SCtx().GetExprCtx(), sel.Conditions, sel.SCtx().GetClient(), kv.TiFlash) { return attachPlan2Task(sel, mppTask.copy()) } } @@ -1302,7 +1302,7 @@ func CheckAggCanPushCop(sctx PlanContext, aggFuncs []*aggregation.AggFuncDesc, g ret = false break } - if !expression.CanExprsPushDownWithExtraInfo(sctx, aggFunc.Args, client, storeType, aggFunc.Name == ast.AggFuncSum) { + if !expression.CanExprsPushDownWithExtraInfo(sctx.GetExprCtx(), aggFunc.Args, client, storeType, aggFunc.Name == ast.AggFuncSum) { reason = "arguments of AggFunc `" + aggFunc.Name + "` contains unsupported exprs" ret = false break @@ -1313,13 +1313,13 @@ func CheckAggCanPushCop(sctx PlanContext, aggFuncs []*aggregation.AggFuncDesc, g for _, item := range aggFunc.OrderByItems { exprs = append(exprs, item.Expr) } - if !expression.CanExprsPushDownWithExtraInfo(sctx, exprs, client, storeType, false) { + if !expression.CanExprsPushDownWithExtraInfo(sctx.GetExprCtx(), exprs, client, storeType, false) { reason = "arguments of AggFunc `" + aggFunc.Name + "` contains unsupported exprs in order-by clause" ret = false break } } - pb, _ := aggregation.AggFuncToPBExpr(sctx, client, aggFunc, storeType) + pb, _ := aggregation.AggFuncToPBExpr(sctx.GetExprCtx(), client, aggFunc, storeType) if pb == nil { reason = "AggFunc `" + aggFunc.Name + "` can not be converted to pb expr" ret = false @@ -1330,7 +1330,7 @@ func CheckAggCanPushCop(sctx PlanContext, aggFuncs []*aggregation.AggFuncDesc, g reason = "groupByItems contain virtual columns, which is not supported now" ret = false } - if ret && !expression.CanExprsPushDown(sctx, groupByItems, client, storeType) { + if ret && !expression.CanExprsPushDown(sctx.GetExprCtx(), groupByItems, client, storeType) { reason = "groupByItems contain unsupported exprs" ret = false } @@ -1432,7 +1432,7 @@ func BuildFinalModeAggregation( // 1. add all args to partial.GroupByItems foundInGroupBy := false for j, gbyExpr := range partial.GroupByItems { - if gbyExpr.Equal(sctx, distinctArg) && gbyExpr.GetType().Equal(distinctArg.GetType()) { + if gbyExpr.Equal(sctx.GetExprCtx(), distinctArg) && gbyExpr.GetType().Equal(distinctArg.GetType()) { // if the two expressions exactly the same in terms of data types and collation, then can avoid it. foundInGroupBy = true ret = partialGbySchema.Columns[j] @@ -1463,7 +1463,7 @@ func BuildFinalModeAggregation( // items. // maybe we can unify them sometime. // only add firstrow for order by items of group_concat() - firstRow, err := aggregation.NewAggFuncDesc(sctx, ast.AggFuncFirstRow, []expression.Expression{distinctArg}, false) + firstRow, err := aggregation.NewAggFuncDesc(sctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{distinctArg}, false) if err != nil { panic("NewAggFuncDesc FirstRow meets error: " + err.Error()) } @@ -1561,7 +1561,7 @@ func BuildFinalModeAggregation( if aggFunc.Name == ast.AggFuncAvg { cntAgg := aggFunc.Clone() cntAgg.Name = ast.AggFuncCount - err := cntAgg.TypeInfer(sctx) + err := cntAgg.TypeInfer(sctx.GetExprCtx()) if err != nil { // must not happen partial = nil final = original @@ -1632,7 +1632,7 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { // inset a count(column) avgCount := aggFunc.Clone() avgCount.Name = ast.AggFuncCount - err := avgCount.TypeInfer(p.SCtx()) + err := avgCount.TypeInfer(p.SCtx().GetExprCtx()) if err != nil { // must not happen return nil } @@ -1653,9 +1653,9 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { } newSchema.Append(avgSumCol) // avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end) - eq := expression.NewFunctionInternal(p.SCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), avgCountCol, expression.NewZero()) - caseWhen := expression.NewFunctionInternal(p.SCtx(), ast.Case, avgCountCol.RetType, eq, expression.NewOne(), avgCountCol) - divide := expression.NewFunctionInternal(p.SCtx(), ast.Div, avgSumCol.RetType, avgSumCol, caseWhen) + eq := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), avgCountCol, expression.NewZero()) + caseWhen := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Case, avgCountCol.RetType, eq, expression.NewOne(), avgCountCol) + divide := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Div, avgSumCol.RetType, avgSumCol, caseWhen) divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType exprs = append(exprs, divide) } else { @@ -1851,7 +1851,7 @@ func (p *basePhysicalAgg) canUse3Stage4SingleDistinctAgg() bool { func genFirstRowAggForGroupBy(ctx PlanContext, groupByItems []expression.Expression) ([]*aggregation.AggFuncDesc, error) { aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(groupByItems)) for _, groupBy := range groupByItems { - agg, err := aggregation.NewAggFuncDesc(ctx, ast.AggFuncFirstRow, []expression.Expression{groupBy}, false) + agg, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{groupBy}, false) if err != nil { return nil, err } @@ -1896,7 +1896,7 @@ func RemoveUnnecessaryFirstRow( if _, ok := gbyExpr.(*expression.Constant); ok { continue } - if gbyExpr.Equal(sctx, aggFunc.Args[0]) { + if gbyExpr.Equal(sctx.GetExprCtx(), aggFunc.Args[0]) { canOptimize = true firstRowFuncMap[aggFunc].Args[0] = finalGbyItems[j] break @@ -2237,8 +2237,8 @@ func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg PhysicalPlan if !fun.HasDistinct { // for normal agg phase1, we should also modify them to target for specified group data. // Expr = (case when groupingID = targeted_groupingID then arg else null end) - eqExpr := expression.NewFunctionInternal(p.SCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) - caseWhen := expression.NewFunctionInternal(p.SCtx(), ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) + eqExpr := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) + caseWhen := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Case, fun.Args[0].GetType(), eqExpr, fun.Args[0], expression.NewNull()) caseWhenProjCol := &expression.Column{ UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), RetType: fun.Args[0].GetType(), diff --git a/pkg/planner/core/tiflash_selection_late_materialization.go b/pkg/planner/core/tiflash_selection_late_materialization.go index 2335e0a26b5df..70e80ee5fa41e 100644 --- a/pkg/planner/core/tiflash_selection_late_materialization.go +++ b/pkg/planner/core/tiflash_selection_late_materialization.go @@ -250,7 +250,7 @@ func predicatePushDownToTableScanImpl(sctx PlanContext, physicalSelection *Physi if len(selectedConds) == 0 { return } - logutil.BgLogger().Debug("planner: push down conditions to table scan", zap.String("table", physicalTableScan.Table.Name.L), zap.String("conditions", string(expression.SortedExplainExpressionList(sctx, selectedConds)))) + logutil.BgLogger().Debug("planner: push down conditions to table scan", zap.String("table", physicalTableScan.Table.Name.L), zap.String("conditions", string(expression.SortedExplainExpressionList(sctx.GetExprCtx(), selectedConds)))) PushedDown(physicalSelection, physicalTableScan, selectedConds, selectedSelectivity) } diff --git a/pkg/planner/core/util.go b/pkg/planner/core/util.go index d557c5685aba7..fca1573d2da9d 100644 --- a/pkg/planner/core/util.go +++ b/pkg/planner/core/util.go @@ -137,7 +137,7 @@ func (s *logicalSchemaProducer) setSchemaAndNames(schema *expression.Schema, nam // inlineProjection prunes unneeded columns inline a executor. func (s *logicalSchemaProducer) inlineProjection(parentUsedCols []*expression.Column, opt *logicalOptimizeOp) { prunedColumns := make([]*expression.Column, 0) - used := expression.GetUsedList(s.SCtx(), parentUsedCols, s.Schema()) + used := expression.GetUsedList(s.SCtx().GetExprCtx(), parentUsedCols, s.Schema()) for i := len(used) - 1; i >= 0; i-- { if !used[i] { prunedColumns = append(prunedColumns, s.Schema().Columns[i]) diff --git a/pkg/planner/util/path.go b/pkg/planner/util/path.go index 925e3fa7843d2..48b90029e12d6 100644 --- a/pkg/planner/util/path.go +++ b/pkg/planner/util/path.go @@ -124,7 +124,7 @@ func (path *AccessPath) IsTablePath() bool { func (path *AccessPath) SplitCorColAccessCondFromFilters(ctx context.PlanContext, eqOrInCount int) (access, remained []expression.Expression) { // The plan cache do not support subquery now. So we skip this function when // 'MaybeOverOptimized4PlanCache' function return true . - if expression.MaybeOverOptimized4PlanCache(ctx, path.TableFilters) { + if expression.MaybeOverOptimized4PlanCache(ctx.GetExprCtx(), path.TableFilters) { return nil, path.TableFilters } access = make([]expression.Expression, len(path.IdxCols)-eqOrInCount) @@ -334,7 +334,7 @@ func CompareCol2Len(c1, c2 Col2Len) (int, bool) { // GetCol2LenFromAccessConds returns columns with lengths from path.AccessConds. func (path *AccessPath) GetCol2LenFromAccessConds(ctx context.PlanContext) Col2Len { if path.IsTablePath() { - return ExtractCol2Len(ctx, path.AccessConds, nil, nil) + return ExtractCol2Len(ctx.GetExprCtx(), path.AccessConds, nil, nil) } - return ExtractCol2Len(ctx, path.AccessConds, path.IdxCols, path.IdxColLens) + return ExtractCol2Len(ctx.GetExprCtx(), path.AccessConds, path.IdxCols, path.IdxColLens) } diff --git a/pkg/session/bootstrap_test.go b/pkg/session/bootstrap_test.go index 6e1632d5f9598..66dd4fafb6c83 100644 --- a/pkg/session/bootstrap_test.go +++ b/pkg/session/bootstrap_test.go @@ -151,7 +151,7 @@ func TestBootstrapWithError(t *testing.T) { sessionVars: variable.NewSessionVars(nil), } se.exprctx = newExpressionContextImpl(se) - se.pctx = newPlanContextImpl(se, se.exprctx.ExprCtxExtendedImpl) + se.pctx = newPlanContextImpl(se) se.tblctx = tbctximpl.NewTableContextImpl(se, se.exprctx) globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() se.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor diff --git a/pkg/session/contextimpl.go b/pkg/session/contextimpl.go index bbbbf65a94437..d858bb2864363 100644 --- a/pkg/session/contextimpl.go +++ b/pkg/session/contextimpl.go @@ -31,10 +31,10 @@ type planContextImpl struct { } // NewPlanContextImpl creates a new PlanContextImpl. -func newPlanContextImpl(s *session, exprExtended *exprctximpl.ExprCtxExtendedImpl) *planContextImpl { +func newPlanContextImpl(s *session) *planContextImpl { return &planContextImpl{ session: s, - PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s, exprExtended), + PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s), } } diff --git a/pkg/session/session.go b/pkg/session/session.go index cd04ff34061c2..121ac5c228a98 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -3571,7 +3571,7 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { } s.sessionVars = variable.NewSessionVars(s) s.exprctx = newExpressionContextImpl(s) - s.pctx = newPlanContextImpl(s, s.exprctx.ExprCtxExtendedImpl) + s.pctx = newPlanContextImpl(s) s.tblctx = tbctximpl.NewTableContextImpl(s, s.exprctx) if opt != nil && opt.PreparedPlanCache != nil { @@ -3634,7 +3634,7 @@ func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), } s.exprctx = newExpressionContextImpl(s) - s.pctx = newPlanContextImpl(s, s.exprctx.ExprCtxExtendedImpl) + s.pctx = newPlanContextImpl(s) s.tblctx = tbctximpl.NewTableContextImpl(s, s.exprctx) s.mu.values = make(map[fmt.Stringer]any) s.lockedTables = make(map[int64]model.TableLockTpInfo) diff --git a/pkg/util/ranger/detacher.go b/pkg/util/ranger/detacher.go index ec58b986c223a..b15680d3a3f86 100644 --- a/pkg/util/ranger/detacher.go +++ b/pkg/util/ranger/detacher.go @@ -131,7 +131,7 @@ func getPotentialEqOrInColOffset(sctx planctx.PlanContext, expr expression.Expre return -1 } if constVal, ok := f.GetArgs()[1].(*expression.Constant); ok { - val, err := constVal.Eval(sctx, chunk.Row{}) + val, err := constVal.Eval(sctx.GetExprCtx(), chunk.Row{}) if err != nil || (!sctx.GetSessionVars().RegardNULLAsPoint && val.IsNull()) { // treat col<=>null as range scan instead of point get to avoid incorrect results // when nullable unique index has multiple matches for filter x is null @@ -139,7 +139,7 @@ func getPotentialEqOrInColOffset(sctx planctx.PlanContext, expr expression.Expre } for i, col := range cols { // When cols are a generated expression col, compare them in terms of virtual expr. - if col.EqualByExprAndID(sctx, c) { + if col.EqualByExprAndID(sctx.GetExprCtx(), c) { return i } } @@ -153,7 +153,7 @@ func getPotentialEqOrInColOffset(sctx planctx.PlanContext, expr expression.Expre return -1 } if constVal, ok := f.GetArgs()[0].(*expression.Constant); ok { - val, err := constVal.Eval(sctx, chunk.Row{}) + val, err := constVal.Eval(sctx.GetExprCtx(), chunk.Row{}) if err != nil || (!sctx.GetSessionVars().RegardNULLAsPoint && val.IsNull()) { return -1 } @@ -373,7 +373,7 @@ func (d *rangeDetacher) detachCNFCondAndBuildRangeForIndex(conditions []expressi checkerCol: d.cols[eqOrInCount], length: d.lengths[eqOrInCount], optPrefixIndexSingleScan: d.sctx.GetSessionVars().OptPrefixIndexSingleScan, - ctx: d.sctx, + ctx: d.sctx.GetExprCtx(), } if considerDNF { bestCNFItemRes, columnValues, err := extractBestCNFItemRanges(d.sctx, conditions, d.cols, d.lengths, d.rangeMaxSize, d.convertToSortKey) @@ -451,7 +451,7 @@ func (d *rangeDetacher) detachCNFCondAndBuildRangeForIndex(conditions []expressi return res, nil } // `eqOrInCount` must be 0 when coming here. - res.AccessConds, res.RemainedConds = detachColumnCNFConditions(d.sctx, newConditions, checker) + res.AccessConds, res.RemainedConds = detachColumnCNFConditions(d.sctx.GetExprCtx(), newConditions, checker) ranges, res.AccessConds, remainedConds, err = d.buildCNFIndexRange(newTpSlice, 0, res.AccessConds) if err != nil { return nil, err @@ -638,7 +638,7 @@ func ExtractEqAndInCondition(sctx planctx.PlanContext, conditions []expression.E } points[offset] = rb.intersection(points[offset], rb.build(cond, newTp, types.UnspecifiedLength, false), collator) if len(points[offset]) == 0 { // Early termination if false expression found - if expression.MaybeOverOptimized4PlanCache(sctx, conditions) { + if expression.MaybeOverOptimized4PlanCache(sctx.GetExprCtx(), conditions) { // `a>@x and a<@y` --> `invalid-range if @x>=@y` sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackErrorf("some parameters may be overwritten")) } @@ -662,21 +662,21 @@ func ExtractEqAndInCondition(sctx planctx.PlanContext, conditions []expression.E // There exists an interval whose length is larger than 0 accesses[i] = nil } else if len(points[i]) == 0 { // Early termination if false expression found - if expression.MaybeOverOptimized4PlanCache(sctx, conditions) { + if expression.MaybeOverOptimized4PlanCache(sctx.GetExprCtx(), conditions) { // `a>@x and a<@y` --> `invalid-range if @x>=@y` sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackErrorf("some parameters may be overwritten")) } return nil, nil, nil, nil, true } else { // All Intervals are single points - accesses[i] = points2EqOrInCond(sctx, points[i], cols[i]) + accesses[i] = points2EqOrInCond(sctx.GetExprCtx(), points[i], cols[i]) newConditions = append(newConditions, accesses[i]) if f, ok := accesses[i].(*expression.ScalarFunction); ok && f.FuncName.L == ast.EQ { // Actually the constant column value may not be mutable. Here we assume it is mutable to keep it simple. // Maybe we can improve it later. columnValues[i] = &valueInfo{mutable: true} } - if expression.MaybeOverOptimized4PlanCache(sctx, conditions) { + if expression.MaybeOverOptimized4PlanCache(sctx.GetExprCtx(), conditions) { // `a=@x and a=@y` --> `a=@x if @x==@y` sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackErrorf("some parameters may be overwritten")) } @@ -717,7 +717,7 @@ func (d *rangeDetacher) detachDNFCondAndBuildRangeForIndex(condition *expression checkerCol: d.cols[0], length: d.lengths[0], optPrefixIndexSingleScan: d.sctx.GetSessionVars().OptPrefixIndexSingleScan, - ctx: d.sctx, + ctx: d.sctx.GetExprCtx(), } rb := builder{sctx: d.sctx} dnfItems := expression.FlattenDNFConditions(condition) @@ -749,7 +749,7 @@ func (d *rangeDetacher) detachDNFCondAndBuildRangeForIndex(condition *expression d.sctx.GetSessionVars().StmtCtx.RecordRangeFallback(d.rangeMaxSize) return FullRange(), nil, nil, true, nil } - newAccessItems = append(newAccessItems, expression.ComposeCNFCondition(d.sctx, accesses...)) + newAccessItems = append(newAccessItems, expression.ComposeCNFCondition(d.sctx.GetExprCtx(), accesses...)) if res.ColumnValues != nil { if i == 0 { columnValues = res.ColumnValues @@ -818,7 +818,7 @@ func (d *rangeDetacher) detachDNFCondAndBuildRangeForIndex(condition *expression return nil, nil, nil, false, errors.Trace(err) } - return totalRanges, []expression.Expression{expression.ComposeDNFCondition(d.sctx, newAccessItems...)}, columnValues, hasResidual, nil + return totalRanges, []expression.Expression{expression.ComposeDNFCondition(d.sctx.GetExprCtx(), newAccessItems...)}, columnValues, hasResidual, nil } // valueInfo is used for recording the constant column value in DetachCondAndBuildRangeForIndex. @@ -994,7 +994,7 @@ func ExtractAccessConditionsForColumn(ctx planctx.PlanContext, conds []expressio checkerCol: col, length: types.UnspecifiedLength, optPrefixIndexSingleScan: ctx.GetSessionVars().OptPrefixIndexSingleScan, - ctx: ctx, + ctx: ctx.GetExprCtx(), } accessConds := make([]expression.Expression, 0, 8) filter := func(expr expression.Expression) bool { @@ -1010,9 +1010,9 @@ func DetachCondsForColumn(sctx planctx.PlanContext, conds []expression.Expressio checkerCol: col, length: types.UnspecifiedLength, optPrefixIndexSingleScan: sctx.GetSessionVars().OptPrefixIndexSingleScan, - ctx: sctx, + ctx: sctx.GetExprCtx(), } - return detachColumnCNFConditions(sctx, conds, checker) + return detachColumnCNFConditions(sctx.GetExprCtx(), conds, checker) } // MergeDNFItems4Col receives a slice of DNF conditions, merges some of them which can be built into ranges on a single column, then returns. @@ -1034,7 +1034,7 @@ func MergeDNFItems4Col(ctx planctx.PlanContext, dnfItems []expression.Expression checkerCol: cols[0], length: types.UnspecifiedLength, optPrefixIndexSingleScan: ctx.GetSessionVars().OptPrefixIndexSingleScan, - ctx: ctx, + ctx: ctx.GetExprCtx(), } // If we can't use this condition to build range, we can't merge it. // Currently, we assume if every condition in a DNF expression can pass this check, then `Selectivity` must be able to @@ -1049,7 +1049,7 @@ func MergeDNFItems4Col(ctx planctx.PlanContext, dnfItems []expression.Expression col2DNFItems[uniqueID] = append(col2DNFItems[uniqueID], dnfItem) } for _, items := range col2DNFItems { - mergedDNFItems = append(mergedDNFItems, expression.ComposeDNFCondition(ctx, items...)) + mergedDNFItems = append(mergedDNFItems, expression.ComposeDNFCondition(ctx.GetExprCtx(), items...)) } return mergedDNFItems } @@ -1109,31 +1109,31 @@ func AddGcColumn4InCond(sctx planctx.PlanContext, for i, arg := range sf.GetArgs()[1:] { // get every const value and calculate tidb_shard(val) con := arg.(*expression.Constant) - conVal, err := con.Eval(sctx, chunk.Row{}) + conVal, err := con.Eval(sctx.GetExprCtx(), chunk.Row{}) if err != nil { return accessesCond, err } record[0] = conVal mutRow := chunk.MutRowFromDatums(record) - exprVal, err := expr.Eval(sctx, mutRow.ToRow()) + exprVal, err := expr.Eval(sctx.GetExprCtx(), mutRow.ToRow()) if err != nil { return accessesCond, err } // tmpArg1 is like `tidb_shard(a) = 8`, tmpArg2 is like `a = 100` exprCon := &expression.Constant{Value: exprVal, RetType: cols[0].RetType} - tmpArg1, err := expression.NewFunction(sctx, ast.EQ, cols[0].RetType, cols[0], exprCon) + tmpArg1, err := expression.NewFunction(sctx.GetExprCtx(), ast.EQ, cols[0].RetType, cols[0], exprCon) if err != nil { return accessesCond, err } - tmpArg2, err := expression.NewFunction(sctx, ast.EQ, c.RetType, c.Clone(), arg) + tmpArg2, err := expression.NewFunction(sctx.GetExprCtx(), ast.EQ, c.RetType, c.Clone(), arg) if err != nil { return accessesCond, err } // make a LogicAnd, e.g. `tidb_shard(a) = 8 AND a = 100` - andExpr, err := expression.NewFunction(sctx, ast.LogicAnd, andType, tmpArg1, tmpArg2) + andExpr, err := expression.NewFunction(sctx.GetExprCtx(), ast.LogicAnd, andType, tmpArg1, tmpArg2) if err != nil { return accessesCond, err } @@ -1143,7 +1143,7 @@ func AddGcColumn4InCond(sctx planctx.PlanContext, } else { // if the LogicAnd more than one, make a LogicOr, // e.g. `(tidb_shard(a) = 8 AND a = 100) OR (tidb_shard(a) = 161 AND a = 200)` - andOrExpr, errRes = expression.NewFunction(sctx, ast.LogicOr, andType, andOrExpr, andExpr) + andOrExpr, errRes = expression.NewFunction(sctx.GetExprCtx(), ast.LogicOr, andType, andOrExpr, andExpr) if errRes != nil { return accessesCond, errRes } @@ -1177,14 +1177,14 @@ func AddGcColumn4EqCond(sctx planctx.PlanContext, } mutRow := chunk.MutRowFromDatums(record) - evaluated, err := expr.Eval(sctx, mutRow.ToRow()) + evaluated, err := expr.Eval(sctx.GetExprCtx(), mutRow.ToRow()) if err != nil { return accessesCond, err } vi := &valueInfo{&evaluated, false} con := &expression.Constant{Value: evaluated, RetType: cols[0].RetType} // make a tidb_shard() function, e.g. `tidb_shard(a) = 8` - cond, err := expression.NewFunction(sctx, ast.EQ, cols[0].RetType, cols[0], con) + cond, err := expression.NewFunction(sctx.GetExprCtx(), ast.EQ, cols[0].RetType, cols[0], con) if err != nil { return accessesCond, err } diff --git a/pkg/util/ranger/points.go b/pkg/util/ranger/points.go index 05752d6f76ea4..b12102e6d7292 100644 --- a/pkg/util/ranger/points.go +++ b/pkg/util/ranger/points.go @@ -252,7 +252,7 @@ func (r *builder) build( } func (r *builder) buildFromConstant(expr *expression.Constant) []*point { - dt, err := expr.Eval(r.sctx, chunk.Row{}) + dt, err := expr.Eval(r.sctx.GetExprCtx(), chunk.Row{}) if err != nil { r.err = err return nil @@ -342,7 +342,7 @@ func (r *builder) buildFromBinOp( var ok bool if col, ok = expr.GetArgs()[0].(*expression.Column); ok { ft = col.RetType - value, err = expr.GetArgs()[1].Eval(r.sctx, chunk.Row{}) + value, err = expr.GetArgs()[1].Eval(r.sctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil } @@ -353,7 +353,7 @@ func (r *builder) buildFromBinOp( return nil } ft = col.RetType - value, err = expr.GetArgs()[0].Eval(r.sctx, chunk.Row{}) + value, err = expr.GetArgs()[0].Eval(r.sctx.GetExprCtx(), chunk.Row{}) if err != nil { return nil } @@ -640,7 +640,7 @@ func (r *builder) buildFromIn( r.err = plannererrors.ErrUnsupportedType.GenWithStack("expr:%v is not constant", e) return getFullRange(), hasNull } - dt, err := v.Eval(r.sctx, chunk.Row{}) + dt, err := v.Eval(r.sctx.GetExprCtx(), chunk.Row{}) if err != nil { r.err = plannererrors.ErrUnsupportedType.GenWithStack("expr:%v is not evaluated", e) return getFullRange(), hasNull @@ -728,7 +728,7 @@ func (r *builder) newBuildFromPatternLike( if !collate.CompatibleCollate(expr.GetArgs()[0].GetType().GetCollate(), collation) { return getFullRange() } - pdt, err := expr.GetArgs()[1].(*expression.Constant).Eval(r.sctx, chunk.Row{}) + pdt, err := expr.GetArgs()[1].(*expression.Constant).Eval(r.sctx.GetExprCtx(), chunk.Row{}) tpOfPattern := expr.GetArgs()[0].GetType() if err != nil { r.err = errors.Trace(err) @@ -754,7 +754,7 @@ func (r *builder) newBuildFromPatternLike( return res } lowValue := make([]byte, 0, len(pattern)) - edt, err := expr.GetArgs()[2].(*expression.Constant).Eval(r.sctx, chunk.Row{}) + edt, err := expr.GetArgs()[2].(*expression.Constant).Eval(r.sctx.GetExprCtx(), chunk.Row{}) if err != nil { r.err = errors.Trace(err) return getFullRange()