From f6b6858772d2e5fc7064900111f534c8ae516bf9 Mon Sep 17 00:00:00 2001 From: Zijie Zhang Date: Wed, 4 Sep 2019 14:08:35 +0800 Subject: [PATCH] planner: fix aggregation hint didn't work in some cases (#11996) --- planner/core/expression_rewriter.go | 9 ++++ planner/core/logical_plan_builder.go | 9 ++-- planner/core/physical_plan_test.go | 63 +++++++++++++++------- planner/core/rule_aggregation_push_down.go | 22 ++++---- 4 files changed, 70 insertions(+), 33 deletions(-) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 0e37315805601..4f166c6e02039 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -442,6 +442,9 @@ func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast. // it will be rewrote to t.id < (select max(s.id) from s). func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.Expression, np LogicalPlan, useMin bool, cmpFunc string, all bool) { plan4Agg := LogicalAggregation{}.Init(er.sctx) + if hint := er.b.TableHints(); hint != nil { + plan4Agg.preferAggType = hint.preferAggType + } plan4Agg.SetChildren(np) // Create a "max" or "min" aggregation. @@ -568,6 +571,9 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, }.Init(er.sctx) + if hint := er.b.TableHints(); hint != nil { + plan4Agg.preferAggType = hint.preferAggType + } plan4Agg.SetChildren(np) firstRowResultCol := &expression.Column{ ColName: model.NewCIStr("col_firstRow"), @@ -602,6 +608,9 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, }.Init(er.sctx) + if hint := er.b.TableHints(); hint != nil { + plan4Agg.preferAggType = hint.preferAggType + } plan4Agg.SetChildren(np) firstRowResultCol := &expression.Column{ ColName: model.NewCIStr("col_firstRow"), diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 76c34569a2490..2cf879471f86d 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -97,6 +97,9 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu b.optFlag = b.optFlag | flagEliminateProjection plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx) + if hint := b.TableHints(); hint != nil { + plan4Agg.preferAggType = hint.preferAggType + } schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...) // aggIdxMap maps the old index to new index after applying common aggregation functions elimination. aggIndexMap := make(map[int]int) @@ -149,9 +152,6 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu plan4Agg.GroupByItems = gbyItems plan4Agg.SetSchema(schema4Agg) plan4Agg.collectGroupByColumns() - if hint := b.TableHints(); hint != nil { - plan4Agg.preferAggType = hint.preferAggType - } return plan4Agg, aggIndexMap, nil } @@ -790,6 +790,9 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggr AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()), GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]), }.Init(b.ctx) + if hint := b.TableHints(); hint != nil { + plan4Agg.preferAggType = hint.preferAggType + } plan4Agg.collectGroupByColumns() for _, col := range child.Schema().Columns { aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 6ddeee0a6f32c..79f96c579d1bc 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -1587,31 +1587,28 @@ func (s *testPlanSuite) TestAggregationHints(c *C) { c.Assert(err, IsNil) tests := []struct { - sql string - best string - warning string + sql string + best string + warning string + aggPushDown bool }{ // without Aggregation hints { - sql: "select count(*) from t t1, t t2 where t1.a = t2.b", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg", - warning: "", + sql: "select count(*) from t t1, t t2 where t1.a = t2.b", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg", }, { - sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg", - warning: "", + sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg", }, // with Aggregation hints { - sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg", - warning: "", + sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg", }, { - sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg", - warning: "", + sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg", }, // test conflict warning { @@ -1619,24 +1616,50 @@ func (s *testPlanSuite) TestAggregationHints(c *C) { best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg", warning: "[planner:1815]Optimizer aggregation hints are conflicted", }, + // additional test + { + sql: "select /*+ STREAM_AGG() */ distinct a from t", + best: "TableReader(Table(t)->StreamAgg)->StreamAgg", + }, + { + sql: "select /*+ HASH_AGG() */ t1.a from t t1 where t1.a < any(select t2.b from t t2)", + best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), , 1)]))->TableReader(Table(t)->HashAgg)->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection", + }, + { + sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a != any(select t2.b from t t2)", + best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), , 1)]))->TableReader(Table(t))->Projection->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection", + }, + { + sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a = all(select t2.b from t t2)", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection->HashAgg}->Projection->Projection", + }, + { + sql: "select /*+ STREAM_AGG() */ sum(t1.a) from t t1 join t t2 on t1.b = t2.b group by t1.b", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Sort->Projection->StreamAgg}(test.t2.b,test.t1.b)->HashAgg", + warning: "[planner:1815]Optimizer Hint STREAM_AGG is inapplicable", + aggPushDown: true, + }, } ctx := context.Background() for i, test := range tests { comment := Commentf("case:%v sql:%s", i, test) + se.GetSessionVars().StmtCtx.SetWarnings(nil) + se.GetSessionVars().AllowAggPushDown = test.aggPushDown + stmt, err := s.ParseOneStmt(test.sql, "", "") c.Assert(err, IsNil, comment) p, err := planner.Optimize(ctx, se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, test.best) + c.Assert(core.ToString(p), Equals, test.best, comment) warnings := se.GetSessionVars().StmtCtx.GetWarnings() if test.warning == "" { - c.Assert(len(warnings), Equals, 0) + c.Assert(len(warnings), Equals, 0, comment) } else { - c.Assert(len(warnings), Equals, 1) - c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning) - c.Assert(warnings[0].Err.Error(), Equals, test.warning) + c.Assert(len(warnings), Equals, 1, comment) + c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning, comment) + c.Assert(warnings[0].Err.Error(), Equals, test.warning, comment) } } } diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index 6f7488667b0ac..157f8b6711a7f 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -189,7 +189,7 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a // tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't // process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator. // If the pushed aggregation is grouped by unique key, it's no need to push it down. -func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (_ LogicalPlan, err error) { +func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int, preferAggType uint) (_ LogicalPlan, err error) { child := join.children[childIdx] if aggregation.IsAllFirstRow(aggFuncs) { return child, nil @@ -204,7 +204,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg return child, nil } } - agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols) + agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, preferAggType) if err != nil { return nil, err } @@ -247,10 +247,11 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation. return false } -func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) { +func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, preferAggType uint) (*LogicalAggregation, error) { agg := LogicalAggregation{ - GroupByItems: expression.Column2Exprs(gbyCols), - groupByCols: gbyCols, + GroupByItems: expression.Column2Exprs(gbyCols), + groupByCols: gbyCols, + preferAggType: preferAggType, }.Init(ctx) aggLen := len(aggFuncs) + len(gbyCols) newAggFuncDescs := make([]*aggregation.AggFuncDesc, 0, aggLen) @@ -282,8 +283,9 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs func (a *aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan { ctx := agg.ctx newAgg := LogicalAggregation{ - AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)), - GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)), + AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)), + GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)), + preferAggType: agg.preferAggType, }.Init(ctx) newAgg.SetSchema(agg.schema.Clone()) for _, aggFunc := range agg.AggFuncs { @@ -340,7 +342,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e if rightInvalid { rChild = join.children[1] } else { - rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1, agg.preferAggType) if err != nil { return nil, err } @@ -348,7 +350,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e if leftInvalid { lChild = join.children[0] } else { - lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0, agg.preferAggType) if err != nil { return nil, err } @@ -380,7 +382,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e } else if union, ok1 := child.(*LogicalUnionAll); ok1 { var gbyCols []*expression.Column gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil) - pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols) + pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols, agg.preferAggType) if err != nil { return nil, err }