Skip to content

Commit

Permalink
planner: fix aggregation hint didn't work in some cases (pingcap#11996)
Browse files Browse the repository at this point in the history
  • Loading branch information
foreyes committed Sep 16, 2019
1 parent c8f7267 commit f6b6858
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 33 deletions.
9 changes: 9 additions & 0 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast.
// it will be rewrote to t.id < (select max(s.id) from s).
func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.Expression, np LogicalPlan, useMin bool, cmpFunc string, all bool) {
plan4Agg := LogicalAggregation{}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)

// Create a "max" or "min" aggregation.
Expand Down Expand Up @@ -568,6 +571,9 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),
Expand Down Expand Up @@ -602,6 +608,9 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),
Expand Down
9 changes: 6 additions & 3 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
b.optFlag = b.optFlag | flagEliminateProjection

plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx)
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...)
// aggIdxMap maps the old index to new index after applying common aggregation functions elimination.
aggIndexMap := make(map[int]int)
Expand Down Expand Up @@ -149,9 +152,6 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
plan4Agg.GroupByItems = gbyItems
plan4Agg.SetSchema(schema4Agg)
plan4Agg.collectGroupByColumns()
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
return plan4Agg, aggIndexMap, nil
}

Expand Down Expand Up @@ -790,6 +790,9 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggr
AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()),
GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]),
}.Init(b.ctx)
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.collectGroupByColumns()
for _, col := range child.Schema().Columns {
aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false)
Expand Down
63 changes: 43 additions & 20 deletions planner/core/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1587,56 +1587,79 @@ func (s *testPlanSuite) TestAggregationHints(c *C) {
c.Assert(err, IsNil)

tests := []struct {
sql string
best string
warning string
sql string
best string
warning string
aggPushDown bool
}{
// without Aggregation hints
{
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
warning: "",
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
},
{
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
warning: "",
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
},
// with Aggregation hints
{
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
warning: "",
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
},
{
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
warning: "",
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
},
// test conflict warning
{
sql: "select /*+ HASH_AGG() STREAM_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
warning: "[planner:1815]Optimizer aggregation hints are conflicted",
},
// additional test
{
sql: "select /*+ STREAM_AGG() */ distinct a from t",
best: "TableReader(Table(t)->StreamAgg)->StreamAgg",
},
{
sql: "select /*+ HASH_AGG() */ t1.a from t t1 where t1.a < any(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t)->HashAgg)->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
},
{
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a != any(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t))->Projection->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
},
{
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a = all(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection->HashAgg}->Projection->Projection",
},
{
sql: "select /*+ STREAM_AGG() */ sum(t1.a) from t t1 join t t2 on t1.b = t2.b group by t1.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Sort->Projection->StreamAgg}(test.t2.b,test.t1.b)->HashAgg",
warning: "[planner:1815]Optimizer Hint STREAM_AGG is inapplicable",
aggPushDown: true,
},
}
ctx := context.Background()
for i, test := range tests {
comment := Commentf("case:%v sql:%s", i, test)
se.GetSessionVars().StmtCtx.SetWarnings(nil)
se.GetSessionVars().AllowAggPushDown = test.aggPushDown

stmt, err := s.ParseOneStmt(test.sql, "", "")
c.Assert(err, IsNil, comment)

p, err := planner.Optimize(ctx, se, stmt, s.is)
c.Assert(err, IsNil)
c.Assert(core.ToString(p), Equals, test.best)
c.Assert(core.ToString(p), Equals, test.best, comment)

warnings := se.GetSessionVars().StmtCtx.GetWarnings()
if test.warning == "" {
c.Assert(len(warnings), Equals, 0)
c.Assert(len(warnings), Equals, 0, comment)
} else {
c.Assert(len(warnings), Equals, 1)
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning)
c.Assert(warnings[0].Err.Error(), Equals, test.warning)
c.Assert(len(warnings), Equals, 1, comment)
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning, comment)
c.Assert(warnings[0].Err.Error(), Equals, test.warning, comment)
}
}
}
Expand Down
22 changes: 12 additions & 10 deletions planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a
// tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't
// process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator.
// If the pushed aggregation is grouped by unique key, it's no need to push it down.
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (_ LogicalPlan, err error) {
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int, preferAggType uint) (_ LogicalPlan, err error) {
child := join.children[childIdx]
if aggregation.IsAllFirstRow(aggFuncs) {
return child, nil
Expand All @@ -204,7 +204,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg
return child, nil
}
}
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols)
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, preferAggType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -247,10 +247,11 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation.
return false
}

func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) {
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, preferAggType uint) (*LogicalAggregation, error) {
agg := LogicalAggregation{
GroupByItems: expression.Column2Exprs(gbyCols),
groupByCols: gbyCols,
GroupByItems: expression.Column2Exprs(gbyCols),
groupByCols: gbyCols,
preferAggType: preferAggType,
}.Init(ctx)
aggLen := len(aggFuncs) + len(gbyCols)
newAggFuncDescs := make([]*aggregation.AggFuncDesc, 0, aggLen)
Expand Down Expand Up @@ -282,8 +283,9 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs
func (a *aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan {
ctx := agg.ctx
newAgg := LogicalAggregation{
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
preferAggType: agg.preferAggType,
}.Init(ctx)
newAgg.SetSchema(agg.schema.Clone())
for _, aggFunc := range agg.AggFuncs {
Expand Down Expand Up @@ -340,15 +342,15 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
if rightInvalid {
rChild = join.children[1]
} else {
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1)
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1, agg.preferAggType)
if err != nil {
return nil, err
}
}
if leftInvalid {
lChild = join.children[0]
} else {
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0)
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0, agg.preferAggType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -380,7 +382,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
} else if union, ok1 := child.(*LogicalUnionAll); ok1 {
var gbyCols []*expression.Column
gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil)
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols)
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols, agg.preferAggType)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit f6b6858

Please sign in to comment.