diff --git a/cmd/explaintest/r/topn_push_down.result b/cmd/explaintest/r/topn_push_down.result index 4d40a3b3b8caf..89813ab77c0c0 100644 --- a/cmd/explaintest/r/topn_push_down.result +++ b/cmd/explaintest/r/topn_push_down.result @@ -240,19 +240,19 @@ explain select /*+ TIDB_SMJ(t1, t2) */ * from t t1 join t t2 on t1.a = t2.a limi id count task operator info Limit_11 5.00 root offset:0, count:5 └─MergeJoin_12 5.00 root inner join, left key:test.t1.a, right key:test.t2.a - ├─IndexReader_14 4.00 root index:IndexScan_13 - │ └─IndexScan_13 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:true, stats:pseudo - └─IndexReader_16 4.00 root index:IndexScan_15 - └─IndexScan_15 4.00 cop table:t2, index:a, range:[NULL,+inf], keep order:true, stats:pseudo + ├─IndexReader_15 4.00 root index:IndexScan_14 + │ └─IndexScan_14 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:true, stats:pseudo + └─IndexReader_17 4.00 root index:IndexScan_16 + └─IndexScan_16 4.00 cop table:t2, index:a, range:[NULL,+inf], keep order:true, stats:pseudo explain select /*+ TIDB_SMJ(t1, t2) */ * from t t1 left join t t2 on t1.a = t2.a where t2.a is null limit 5; id count task operator info Limit_12 5.00 root offset:0, count:5 └─Selection_13 5.00 root isnull(test.t2.a) └─MergeJoin_14 5.00 root left outer join, left key:test.t1.a, right key:test.t2.a - ├─IndexReader_16 4.00 root index:IndexScan_15 - │ └─IndexScan_15 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:true, stats:pseudo - └─IndexReader_18 4.00 root index:IndexScan_17 - └─IndexScan_17 4.00 cop table:t2, index:a, range:[NULL,+inf], keep order:true, stats:pseudo + ├─IndexReader_17 4.00 root index:IndexScan_16 + │ └─IndexScan_16 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:true, stats:pseudo + └─IndexReader_19 4.00 root index:IndexScan_18 + └─IndexScan_18 4.00 cop table:t2, index:a, range:[NULL,+inf], keep order:true, stats:pseudo explain select /*+ TIDB_HJ(t1, t2) */ * from t t1 join t t2 on t1.a = t2.a limit 5; id count task operator info Limit_11 5.00 root offset:0, count:5 diff --git a/executor/join_test.go b/executor/join_test.go index 0f0b66214bf9d..1fa8ba6f20c37 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -143,13 +143,13 @@ func (s *testSuite2) TestJoin(c *C) { tk.MustQuery("select /*+ TIDB_INLJ(t1) */ * from t right outer join t1 on t.a=t1.a").Check(testkit.Rows("1 1 1 2", "1 1 1 3", "1 1 1 4", "3 3 3 4", " 4 5")) tk.MustQuery("select /*+ TIDB_INLJ(t) */ avg(t.b) from t right outer join t1 on t.a=t1.a").Check(testkit.Rows("1.5000")) - // Test that two conflict hints will return error. - err := tk.ExecToErr("select /*+ TIDB_INLJ(t) TIDB_SMJ(t) */ * from t join t1 on t.a=t1.a") - c.Assert(err, NotNil) - err = tk.ExecToErr("select /*+ TIDB_INLJ(t) TIDB_HJ(t) */ from t join t1 on t.a=t1.a") - c.Assert(err, NotNil) - err = tk.ExecToErr("select /*+ TIDB_SMJ(t) TIDB_HJ(t) */ from t join t1 on t.a=t1.a") - c.Assert(err, NotNil) + // Test that two conflict hints will return warning. + tk.MustExec("select /*+ TIDB_INLJ(t) TIDB_SMJ(t) */ * from t join t1 on t.a=t1.a") + c.Assert(tk.Se.GetSessionVars().StmtCtx.GetWarnings(), HasLen, 1) + tk.MustExec("select /*+ TIDB_INLJ(t) TIDB_HJ(t) */ * from t join t1 on t.a=t1.a") + c.Assert(tk.Se.GetSessionVars().StmtCtx.GetWarnings(), HasLen, 1) + tk.MustExec("select /*+ TIDB_SMJ(t) TIDB_HJ(t) */ * from t join t1 on t.a=t1.a") + c.Assert(tk.Se.GetSessionVars().StmtCtx.GetWarnings(), HasLen, 1) tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int)") @@ -888,11 +888,11 @@ func (s *testSuite2) TestMergejoinOrder(c *C) { tk.MustExec("insert into t2 select a*100, b*100 from t1;") tk.MustQuery("explain select /*+ TIDB_SMJ(t2) */ * from t1 left outer join t2 on t1.a=t2.a and t1.a!=3 order by t1.a;").Check(testkit.Rows( - "MergeJoin_15 10000.00 root left outer join, left key:test.t1.a, right key:test.t2.a, left cond:[ne(test.t1.a, 3)]", - "├─TableReader_11 10000.00 root data:TableScan_10", - "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo", - "└─TableReader_13 6666.67 root data:TableScan_12", - " └─TableScan_12 6666.67 cop table:t2, range:[-inf,3), (3,+inf], keep order:true, stats:pseudo", + "MergeJoin_20 10000.00 root left outer join, left key:test.t1.a, right key:test.t2.a, left cond:[ne(test.t1.a, 3)]", + "├─TableReader_12 10000.00 root data:TableScan_11", + "│ └─TableScan_11 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo", + "└─TableReader_14 6666.67 root data:TableScan_13", + " └─TableScan_13 6666.67 cop table:t2, range:[-inf,3), (3,+inf], keep order:true, stats:pseudo", )) tk.MustExec("set @@tidb_init_chunk_size=1") diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 5ff411d6e5e7f..73145def95be6 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -117,7 +117,7 @@ func (p *PhysicalMergeJoin) tryToGetChildReqProp(prop *property.PhysicalProperty } func (p *LogicalJoin) getMergeJoin(prop *property.PhysicalProperty) []PhysicalPlan { - joins := make([]PhysicalPlan, 0, len(p.leftProperties)) + joins := make([]PhysicalPlan, 0, len(p.leftProperties)+1) // The leftProperties caches all the possible properties that are provided by its children. for _, lhsChildProperty := range p.leftProperties { offsets := getMaxSortPrefix(lhsChildProperty, p.LeftJoinKeys) @@ -158,10 +158,10 @@ func (p *LogicalJoin) getMergeJoin(prop *property.PhysicalProperty) []PhysicalPl joins = append(joins, mergeJoin) } } - // If TiDB_SMJ hint is existed && no join keys in children property, - // it should to enforce merge join. - if len(joins) == 0 && (p.preferJoinType&preferMergeJoin) > 0 { - return p.getEnforcedMergeJoin(prop) + // If TiDB_SMJ hint is existed, it should consider enforce merge join, + // because we can't trust lhsChildProperty completely. + if (p.preferJoinType & preferMergeJoin) > 0 { + joins = append(joins, p.getEnforcedMergeJoin(prop)...) } return joins @@ -1198,11 +1198,36 @@ func (p *baseLogicalPlan) exhaustPhysicalPlans(_ *property.PhysicalProperty) []P panic("baseLogicalPlan.exhaustPhysicalPlans() should never be called.") } +func (la *LogicalAggregation) getEnforcedStreamAggs(prop *property.PhysicalProperty) []PhysicalPlan { + _, desc := prop.AllSameOrder() + enforcedAggs := make([]PhysicalPlan, 0, len(wholeTaskTypes)) + childProp := &property.PhysicalProperty{ + ExpectedCnt: math.Max(prop.ExpectedCnt*la.inputCount/la.stats.RowCount, prop.ExpectedCnt), + Enforced: true, + Items: property.ItemsFromCols(la.groupByCols, desc), + } + + for _, taskTp := range wholeTaskTypes { + copiedChildProperty := new(property.PhysicalProperty) + *copiedChildProperty = *childProp // It's ok to not deep copy the "cols" field. + copiedChildProperty.TaskTp = taskTp + + agg := basePhysicalAgg{ + GroupByItems: la.GroupByItems, + AggFuncs: la.AggFuncs, + }.initForStream(la.ctx, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), copiedChildProperty) + agg.SetSchema(la.schema.Clone()) + enforcedAggs = append(enforcedAggs, agg) + } + return enforcedAggs +} + func (la *LogicalAggregation) getStreamAggs(prop *property.PhysicalProperty) []PhysicalPlan { all, desc := prop.AllSameOrder() - if len(la.possibleProperties) == 0 || !all { + if !all { return nil } + for _, aggFunc := range la.AggFuncs { if aggFunc.Mode == aggregation.FinalMode { return nil @@ -1213,7 +1238,7 @@ func (la *LogicalAggregation) getStreamAggs(prop *property.PhysicalProperty) []P return nil } - streamAggs := make([]PhysicalPlan, 0, len(la.possibleProperties)*(len(wholeTaskTypes)-1)) + streamAggs := make([]PhysicalPlan, 0, len(la.possibleProperties)*(len(wholeTaskTypes)-1)+len(wholeTaskTypes)) childProp := &property.PhysicalProperty{ ExpectedCnt: math.Max(prop.ExpectedCnt*la.inputCount/la.stats.RowCount, prop.ExpectedCnt), } @@ -1244,6 +1269,11 @@ func (la *LogicalAggregation) getStreamAggs(prop *property.PhysicalProperty) []P streamAggs = append(streamAggs, agg) } } + // If STREAM_AGG hint is existed, it should consider enforce stream aggregation, + // because we can't trust possibleChildProperty completely. + if (la.preferAggType & preferStreamAgg) > 0 { + streamAggs = append(streamAggs, la.getEnforcedStreamAggs(prop)...) + } return streamAggs } @@ -1264,9 +1294,35 @@ func (la *LogicalAggregation) getHashAggs(prop *property.PhysicalProperty) []Phy } func (la *LogicalAggregation) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan { - aggs := make([]PhysicalPlan, 0, len(la.possibleProperties)+1) - aggs = append(aggs, la.getHashAggs(prop)...) - aggs = append(aggs, la.getStreamAggs(prop)...) + preferHash := (la.preferAggType & preferHashAgg) > 0 + preferStream := (la.preferAggType & preferStreamAgg) > 0 + if preferHash && preferStream { + errMsg := "Optimizer aggregation hints are conflicted" + warning := ErrInternal.GenWithStack(errMsg) + la.ctx.GetSessionVars().StmtCtx.AppendWarning(warning) + la.preferAggType = 0 + preferHash, preferStream = false, false + } + + hashAggs := la.getHashAggs(prop) + if hashAggs != nil && preferHash { + return hashAggs + } + + streamAggs := la.getStreamAggs(prop) + if streamAggs != nil && preferStream { + return streamAggs + } + + if streamAggs == nil && preferStream { + errMsg := "Optimizer Hint TIDB_STREAMAGG is inapplicable" + warning := ErrInternal.GenWithStack(errMsg) + la.ctx.GetSessionVars().StmtCtx.AppendWarning(warning) + } + + aggs := make([]PhysicalPlan, 0, len(hashAggs)+len(streamAggs)) + aggs = append(aggs, hashAggs...) + aggs = append(aggs, streamAggs...) return aggs } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index ef197daa5a63b..af1be06afc031 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -756,10 +756,7 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte join.attachOnConds(expression.SplitCNFItems(checkCondition)) // Set join hint for this join. if er.b.TableHints() != nil { - er.err = join.setPreferredJoinType(er.b.TableHints()) - if er.err != nil { - return v, true - } + join.setPreferredJoinType(er.b.TableHints()) } er.p = join } else { diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 69e831a015009..a85a635a791af 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -54,6 +54,10 @@ const ( TiDBIndexNestedLoopJoin = "tidb_inlj" // TiDBHashJoin is hint enforce hash join. TiDBHashJoin = "tidb_hj" + // TiDBHashAgg is hint enforce hash aggregation. + TiDBHashAgg = "tidb_hashagg" + // TiDBStreamAgg is hint enforce stream aggregation. + TiDBStreamAgg = "tidb_streamagg" ) const ( @@ -137,6 +141,9 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu plan4Agg.GroupByItems = gbyItems plan4Agg.SetSchema(schema4Agg) plan4Agg.collectGroupByColumns() + if hint := b.TableHints(); hint != nil { + plan4Agg.preferAggType = hint.preferAggType + } return plan4Agg, aggIndexMap, nil } @@ -327,9 +334,9 @@ func extractTableAlias(p LogicalPlan) *model.CIStr { return nil } -func (p *LogicalJoin) setPreferredJoinType(hintInfo *tableHintInfo) error { +func (p *LogicalJoin) setPreferredJoinType(hintInfo *tableHintInfo) { if hintInfo == nil { - return nil + return } lhsAlias := extractTableAlias(p.children[0]) @@ -355,9 +362,10 @@ func (p *LogicalJoin) setPreferredJoinType(hintInfo *tableHintInfo) error { // If there're multiple join types and one of them is not index join hint, // then there is a conflict of join types. if bits.OnesCount(p.preferJoinType) > 1 && (p.preferJoinType^preferRightAsIndexInner^preferLeftAsIndexInner) > 0 { - return errors.New("Join hints are conflict, you can only specify one type of join") + errMsg := "Join hints are conflict, you can only specify one type of join" + warning := ErrInternal.GenWithStack(errMsg) + p.ctx.GetSessionVars().StmtCtx.AppendWarning(warning) } - return nil } func resetNotNullFlag(schema *expression.Schema, start, end int) { @@ -424,10 +432,7 @@ func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (Logica joinPlan.redundantSchema = expression.MergeSchema(lRedundant, rRedundant) // Set preferred join algorithm if some join hints is specified by user. - err = joinPlan.setPreferredJoinType(b.TableHints()) - if err != nil { - return nil, err - } + joinPlan.setPreferredJoinType(b.TableHints()) // "NATURAL JOIN" doesn't have "ON" or "USING" conditions. // @@ -1931,8 +1936,9 @@ func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectFi return resultList, nil } -func (b *PlanBuilder) pushTableHints(hints []*ast.TableOptimizerHint) bool { +func (b *PlanBuilder) pushTableHints(hints []*ast.TableOptimizerHint) { var sortMergeTables, INLJTables, hashJoinTables []hintTableInfo + var preferAggType uint for _, hint := range hints { switch hint.HintName.L { case TiDBMergeJoin: @@ -1941,19 +1947,20 @@ func (b *PlanBuilder) pushTableHints(hints []*ast.TableOptimizerHint) bool { INLJTables = tableNames2HintTableInfo(hint.Tables) case TiDBHashJoin: hashJoinTables = tableNames2HintTableInfo(hint.Tables) + case TiDBHashAgg: + preferAggType |= preferHashAgg + case TiDBStreamAgg: + preferAggType |= preferStreamAgg default: // ignore hints that not implemented } } - if len(sortMergeTables)+len(INLJTables)+len(hashJoinTables) > 0 { - b.tableHintInfo = append(b.tableHintInfo, tableHintInfo{ - sortMergeJoinTables: sortMergeTables, - indexNestedLoopJoinTables: INLJTables, - hashJoinTables: hashJoinTables, - }) - return true - } - return false + b.tableHintInfo = append(b.tableHintInfo, tableHintInfo{ + sortMergeJoinTables: sortMergeTables, + indexNestedLoopJoinTables: INLJTables, + hashJoinTables: hashJoinTables, + preferAggType: preferAggType, + }) } func (b *PlanBuilder) popTableHints() { @@ -1983,10 +1990,10 @@ func (b *PlanBuilder) TableHints() *tableHintInfo { } func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p LogicalPlan, err error) { - if b.pushTableHints(sel.TableHints) { - // table hints are only visible in the current SELECT statement. - defer b.popTableHints() - } + b.pushTableHints(sel.TableHints) + // table hints are only visible in the current SELECT statement. + defer b.popTableHints() + if sel.SelectStmtOpts != nil { origin := b.inStraightJoin b.inStraightJoin = sel.SelectStmtOpts.StraightJoin @@ -2602,10 +2609,9 @@ func buildColumns2Handle( } func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (Plan, error) { - if b.pushTableHints(update.TableHints) { - // table hints are only visible in the current UPDATE statement. - defer b.popTableHints() - } + b.pushTableHints(update.TableHints) + // table hints are only visible in the current UPDATE statement. + defer b.popTableHints() // update subquery table should be forbidden var asNameList []string @@ -2823,10 +2829,9 @@ func extractTableAsNameForUpdate(p LogicalPlan, asNames map[*model.TableInfo][]* } func (b *PlanBuilder) buildDelete(ctx context.Context, delete *ast.DeleteStmt) (Plan, error) { - if b.pushTableHints(delete.TableHints) { - // table hints are only visible in the current DELETE statement. - defer b.popTableHints() - } + b.pushTableHints(delete.TableHints) + // table hints are only visible in the current DELETE statement. + defer b.popTableHints() p, err := b.buildResultSetNode(ctx, delete.TableRefs.TableRefs) if err != nil { diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 25094e5d0028b..c459d0d75810f 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -97,6 +97,8 @@ const ( preferRightAsIndexInner preferHashJoin preferMergeJoin + preferHashAgg + preferStreamAgg ) // LogicalJoin is the logical join plan. @@ -246,6 +248,9 @@ type LogicalAggregation struct { // groupByCols stores the columns that are group-by items. groupByCols []*expression.Column + // preferAggType stores preferred aggregation algorithm type. + preferAggType uint + possibleProperties [][]*expression.Column inputCount float64 // inputCount is the input count of this plan. } diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 8fe8eeff128f7..7d7be7839e4f9 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -1521,7 +1521,7 @@ func (s *testPlanSuite) TestUnmatchedTableInHint(c *C) { } } -func (s *testPlanSuite) TestIndexJoinHint(c *C) { +func (s *testPlanSuite) TestJoinHints(c *C) { defer testleak.AfterTest(c)() store, dom, err := newStoreWithBootstrap() c.Assert(err, IsNil) @@ -1570,3 +1570,159 @@ func (s *testPlanSuite) TestIndexJoinHint(c *C) { } } } + +func (s *testPlanSuite) TestAggregationHints(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + + tests := []struct { + sql string + best string + warning string + }{ + // without Aggregation hints + { + sql: "select count(*) from t t1, t t2 where t1.a = t2.b", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg", + warning: "", + }, + { + sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg", + warning: "", + }, + // with Aggregation hints + { + sql: "select /*+ TIDB_HASHAGG() */ count(*) from t t1, t t2 where t1.a = t2.b", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg", + warning: "", + }, + { + sql: "select /*+ TIDB_STREAMAGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg", + warning: "", + }, + // test conflict warning + { + sql: "select /*+ TIDB_HASHAGG() TIDB_STREAMAGG() */ count(*) from t t1, t t2 where t1.a = t2.b", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg", + warning: "[planner:1815]Optimizer aggregation hints are conflicted", + }, + } + ctx := context.Background() + for i, test := range tests { + comment := Commentf("case:%v sql:%s", i, test) + stmt, err := s.ParseOneStmt(test.sql, "", "") + c.Assert(err, IsNil, comment) + + p, err := planner.Optimize(ctx, se, stmt, s.is) + c.Assert(err, IsNil) + c.Assert(core.ToString(p), Equals, test.best) + + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + if test.warning == "" { + c.Assert(len(warnings), Equals, 0) + } else { + c.Assert(len(warnings), Equals, 1) + c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning) + c.Assert(warnings[0].Err.Error(), Equals, test.warning) + } + } +} + +func (s *testPlanSuite) TestHintScope(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + + tests := []struct { + sql string + best string + }{ + // join hints + { + sql: "select /*+ TIDB_SMJ(t1) */ t1.a, t1.b from t t1, (select /*+ TIDB_INLJ(t3) */ t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "MergeInnerJoin{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t3.c)}(test.t1.a,test.t2.a)->Projection", + }, + { + sql: "select /*+ TIDB_SMJ(t1) */ t1.a, t1.b from t t1, (select /*+ TIDB_HJ(t2) */ t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "MergeInnerJoin{TableReader(Table(t))->LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t2.a,test.t3.c)->Sort}(test.t1.a,test.t2.a)->Projection", + }, + { + sql: "select /*+ TIDB_INLJ(t1) */ t1.a, t1.b from t t1, (select /*+ TIDB_HJ(t2) */ t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "IndexJoin{TableReader(Table(t))->LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t2.a,test.t3.c)}(test.t2.a,test.t1.a)->Projection", + }, + { + sql: "select /*+ TIDB_INLJ(t1) */ t1.a, t1.b from t t1, (select /*+ TIDB_SMJ(t2) */ t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "IndexJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t3.c)}(test.t2.a,test.t1.a)->Projection", + }, + { + sql: "select /*+ TIDB_HJ(t1) */ t1.a, t1.b from t t1, (select /*+ TIDB_SMJ(t2) */ t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "RightHashJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t3.c)}(test.t1.a,test.t2.a)->Projection", + }, + { + sql: "select /*+ TIDB_HJ(t1) */ t1.a, t1.b from t t1, (select /*+ TIDB_INLJ(t2) */ t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "RightHashJoin{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t3.c,test.t2.a)}(test.t1.a,test.t2.a)->Projection", + }, + { + sql: "select /*+ TIDB_SMJ(t1) */ t1.a, t1.b from t t1, (select t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "MergeInnerJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t3.c)}(test.t1.a,test.t2.a)->Projection", + }, + { + sql: "select /*+ TIDB_INLJ(t1) */ t1.a, t1.b from t t1, (select t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "IndexJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t3.c)}(test.t2.a,test.t1.a)->Projection", + }, + { + sql: "select /*+ TIDB_HJ(t1) */ t1.a, t1.b from t t1, (select t2.a from t t2, t t3 where t2.a = t3.c) s where t1.a=s.a", + best: "RightHashJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t3.c)}(test.t1.a,test.t2.a)->Projection", + }, + // aggregation hints + { + sql: "select /*+ TIDB_STREAMAGG() */ s, count(s) from (select /*+ TIDB_HASHAGG() */ sum(t1.a) as s from t t1, t t2 where t1.a = t2.b group by t1.a) p group by s", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->Projection->HashAgg->Sort->StreamAgg->Projection", + }, + { + sql: "select /*+ TIDB_HASHAGG() */ s, count(s) from (select /*+ TIDB_STREAMAGG() */ sum(t1.a) as s from t t1, t t2 where t1.a = t2.b group by t1.a) p group by s", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->Sort->Projection->StreamAgg->HashAgg->Projection", + }, + { + sql: "select /*+ TIDB_HASHAGG() */ s, count(s) from (select sum(t1.a) as s from t t1, t t2 where t1.a = t2.b group by t1.a) p group by s", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->Projection->HashAgg->HashAgg->Projection", + }, + { + sql: "select /*+ TIDB_STREAMAGG() */ s, count(s) from (select sum(t1.a) as s from t t1, t t2 where t1.a = t2.b group by t1.a) p group by s", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->Projection->HashAgg->Sort->StreamAgg->Projection", + }, + } + + ctx := context.Background() + for i, test := range tests { + comment := Commentf("case:%v sql:%s", i, test) + stmt, err := s.ParseOneStmt(test.sql, "", "") + c.Assert(err, IsNil, comment) + + p, err := planner.Optimize(ctx, se, stmt, s.is) + c.Assert(err, IsNil) + c.Assert(core.ToString(p), Equals, test.best) + + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(warnings, HasLen, 0, comment) + } +} diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 0b351cd360dfb..3134485529e87 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -57,6 +57,7 @@ type tableHintInfo struct { indexNestedLoopJoinTables []hintTableInfo sortMergeJoinTables []hintTableInfo hashJoinTables []hintTableInfo + preferAggType uint } type hintTableInfo struct {