Skip to content

Commit 5c3644d

Browse files
engleflydataroaring
authored andcommitted
[feat](nereids)push Limit to local agg (#34853)
## Proposed changes for a pattern: topn(n)->globalAgg->localAgg this pr tries to add a filter on global/localAgg which means only the first n tuples are counted, and others could be ignored. inorder to obtain this benefit, optimizer will change limit node to topn node if the limit number is less than topnOptLimitThreshold. Issue Number: close #xxx <!--Describe your changes.--> ## Further comments If this is a relatively large or complex change, kick off the discussion at [dev@doris.apache.org](mailto:dev@doris.apache.org) by explaining why you chose the solution you did and what alternatives you considered, etc...
1 parent dffb1ff commit 5c3644d

File tree

64 files changed

+1719
-1053
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1719
-1053
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java

+19-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ public PlanFragment visitPhysicalDistribute(PhysicalDistribute<? extends Plan> d
294294
&& context.getFirstAggregateInFragment(inputFragment) == child) {
295295
PhysicalHashAggregate<?> hashAggregate = (PhysicalHashAggregate<?>) child;
296296
if (hashAggregate.getAggPhase() == AggPhase.LOCAL
297-
&& hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER) {
297+
&& hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER
298+
&& hashAggregate.getTopnPushInfo() == null) {
298299
AggregationNode aggregationNode = (AggregationNode) inputFragment.getPlanRoot();
299300
aggregationNode.setUseStreamingPreagg(hashAggregate.isMaybeUsingStream());
300301
}
@@ -1056,6 +1057,23 @@ public PlanFragment visitPhysicalHashAggregate(
10561057
// local exchanger will be used.
10571058
aggregationNode.setColocate(true);
10581059
}
1060+
if (aggregate.getTopnPushInfo() != null) {
1061+
List<Expr> orderingExprs = Lists.newArrayList();
1062+
List<Boolean> ascOrders = Lists.newArrayList();
1063+
List<Boolean> nullsFirstParams = Lists.newArrayList();
1064+
aggregate.getTopnPushInfo().orderkeys.forEach(k -> {
1065+
orderingExprs.add(ExpressionTranslator.translate(k.getExpr(), context));
1066+
ascOrders.add(k.isAsc());
1067+
nullsFirstParams.add(k.isNullFirst());
1068+
});
1069+
SortInfo sortInfo = new SortInfo(orderingExprs, ascOrders, nullsFirstParams, outputTupleDesc);
1070+
aggregationNode.setSortByGroupKey(sortInfo);
1071+
if (aggregationNode.getLimit() == -1) {
1072+
aggregationNode.setLimit(aggregate.getTopnPushInfo().limit);
1073+
}
1074+
} else {
1075+
aggregationNode.setSortByGroupKey(null);
1076+
}
10591077
setPlanRoot(inputPlanFragment, aggregationNode, aggregate);
10601078
if (aggregate.getStats() != null) {
10611079
aggregationNode.setCardinality((long) aggregate.getStats().getRowCount());

fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java

+2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
import org.apache.doris.nereids.rules.rewrite.InferPredicates;
8585
import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
8686
import org.apache.doris.nereids.rules.rewrite.InlineLogicalView;
87+
import org.apache.doris.nereids.rules.rewrite.LimitAggToTopNAgg;
8788
import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
8889
import org.apache.doris.nereids.rules.rewrite.LogicalResultSinkToShortCircuitPointQuery;
8990
import org.apache.doris.nereids.rules.rewrite.MergeAggregate;
@@ -366,6 +367,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
366367
// generate one PhysicalLimit if current distribution is gather or two
367368
// PhysicalLimits with gather exchange
368369
topDown(new LimitSortToTopN()),
370+
topDown(new LimitAggToTopNAgg()),
369371
topDown(new MergeTopNs()),
370372
topDown(new SplitLimit()),
371373
topDown(

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java

+3
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ public List<PlanPostProcessor> getProcessors() {
6565
builder.add(new AddOffsetIntoDistribute());
6666
builder.add(new CommonSubExpressionOpt());
6767
// DO NOT replace PLAN NODE from here
68+
if (cascadesContext.getConnectContext().getSessionVariable().pushTopnToAgg) {
69+
builder.add(new PushTopnToAgg());
70+
}
6871
builder.add(new TopNScanOpt());
6972
builder.add(new FragmentProcessor());
7073
if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
// This file is copied from
18+
// https://github.com/apache/impala/blob/branch-2.9.0/fe/src/main/java/org/apache/impala/AggregationNode.java
19+
// and modified by Doris
20+
21+
package org.apache.doris.nereids.processor.post;
22+
23+
import org.apache.doris.nereids.CascadesContext;
24+
import org.apache.doris.nereids.properties.DistributionSpecGather;
25+
import org.apache.doris.nereids.properties.OrderKey;
26+
import org.apache.doris.nereids.trees.expressions.Expression;
27+
import org.apache.doris.nereids.trees.plans.AggMode;
28+
import org.apache.doris.nereids.trees.plans.Plan;
29+
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
30+
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
31+
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopnPushInfo;
32+
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
33+
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
34+
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
35+
import org.apache.doris.qe.ConnectContext;
36+
37+
import org.apache.hadoop.util.Lists;
38+
39+
import java.util.List;
40+
import java.util.stream.Collectors;
41+
42+
/**
43+
* Add SortInfo to Agg. This SortInfo is used as boundary, not used to sort elements.
44+
* example
45+
* sql: select count(*) from orders group by o_clerk order by o_clerk limit 1;
46+
* plan: topn(1) -> aggGlobal -> shuffle -> aggLocal -> scan
47+
* optimization: aggLocal and aggGlobal only need to generate the smallest row with respect to o_clerk.
48+
*
49+
* TODO: the following case is not covered:
50+
* sql: select sum(o_shippriority) from orders group by o_clerk limit 1;
51+
* plan: limit -> aggGlobal -> shuffle -> aggLocal -> scan
52+
* aggGlobal may receive partial aggregate results, and hence is not supported now
53+
* instance1: input (key=2, v=1) => localAgg => (2, 1) => aggGlobal inst1 => (2, 1)
54+
* instance2: input (key=1, v=1), (key=2, v=2) => localAgg inst2 => (1, 1)
55+
* (2,1),(1,1) => limit => may output (2, 1), which is not complete, missing (2, 2) in instance2
56+
*
57+
*TOPN:
58+
* Precondition: topn orderkeys are the prefix of group keys
59+
* TODO: topnKeys could be subset of groupKeys. This will be implemented in future
60+
* Pattern 2-phase agg:
61+
* topn -> aggGlobal -> distribute -> aggLocal
62+
* =>
63+
* topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
64+
* Pattern 1-phase agg:
65+
* topn->agg->Any(not agg) -> topn -> agg(topn=n) -> any
66+
*
67+
* LIMIT:
68+
* Pattern 1: limit->agg(1phase)->any
69+
* Pattern 2: limit->agg(global)->gather->agg(local)
70+
*/
71+
public class PushTopnToAgg extends PlanPostProcessor {
72+
@Override
73+
public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
74+
topN.child().accept(this, ctx);
75+
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= topN.getLimit() + topN.getOffset()) {
76+
return topN;
77+
}
78+
Plan topnChild = topN.child();
79+
if (topnChild instanceof PhysicalProject) {
80+
topnChild = topnChild.child(0);
81+
}
82+
if (topnChild instanceof PhysicalHashAggregate) {
83+
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild;
84+
List<OrderKey> orderKeys = tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
85+
if (!orderKeys.isEmpty()) {
86+
87+
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
88+
upperAgg.setTopnPushInfo(new TopnPushInfo(
89+
orderKeys,
90+
topN.getLimit() + topN.getOffset()));
91+
if (upperAgg.child() instanceof PhysicalDistribute
92+
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
93+
PhysicalHashAggregate<? extends Plan> bottomAgg =
94+
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
95+
bottomAgg.setTopnPushInfo(new TopnPushInfo(
96+
orderKeys,
97+
topN.getLimit() + topN.getOffset()));
98+
}
99+
} else if (upperAgg.getAggPhase().isLocal() && upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
100+
// one phase agg
101+
upperAgg.setTopnPushInfo(new TopnPushInfo(
102+
orderKeys,
103+
topN.getLimit() + topN.getOffset()));
104+
}
105+
}
106+
}
107+
return topN;
108+
}
109+
110+
/**
111+
return true, if topn order-key is prefix of agg group-key, ignore asc/desc and null_first
112+
TODO order-key can be subset of group-key. BE does not support now.
113+
*/
114+
private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? extends Plan> topN,
115+
PhysicalHashAggregate<? extends Plan> agg) {
116+
List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
117+
if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) {
118+
return orderKeys;
119+
}
120+
List<Expression> topnKeys = topN.getOrderKeys().stream()
121+
.map(OrderKey::getExpr).collect(Collectors.toList());
122+
for (int i = 0; i < topN.getOrderKeys().size(); i++) {
123+
// prefix check
124+
if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) {
125+
return Lists.newArrayList();
126+
}
127+
orderKeys.add(topN.getOrderKeys().get(i));
128+
}
129+
for (int i = topN.getOrderKeys().size(); i < agg.getGroupByExpressions().size(); i++) {
130+
orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false));
131+
}
132+
return orderKeys;
133+
}
134+
135+
@Override
136+
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
137+
limit.child().accept(this, ctx);
138+
if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= limit.getLimit() + limit.getOffset()) {
139+
return limit;
140+
}
141+
Plan limitChild = limit.child();
142+
if (limitChild instanceof PhysicalProject) {
143+
limitChild = limitChild.child(0);
144+
}
145+
if (limitChild instanceof PhysicalHashAggregate) {
146+
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild;
147+
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
148+
Plan child = upperAgg.child();
149+
Plan grandChild = child.child(0);
150+
if (child instanceof PhysicalDistribute
151+
&& ((PhysicalDistribute<?>) child).getDistributionSpec() instanceof DistributionSpecGather
152+
&& grandChild instanceof PhysicalHashAggregate) {
153+
upperAgg.setTopnPushInfo(new TopnPushInfo(
154+
generateOrderKeyByGroupKey(upperAgg),
155+
limit.getLimit() + limit.getOffset()));
156+
PhysicalHashAggregate<? extends Plan> bottomAgg =
157+
(PhysicalHashAggregate<? extends Plan>) grandChild;
158+
bottomAgg.setTopnPushInfo(new TopnPushInfo(
159+
generateOrderKeyByGroupKey(bottomAgg),
160+
limit.getLimit() + limit.getOffset()));
161+
}
162+
} else if (upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
163+
// 1-phase agg
164+
upperAgg.setTopnPushInfo(new TopnPushInfo(
165+
generateOrderKeyByGroupKey(upperAgg),
166+
limit.getLimit() + limit.getOffset()));
167+
}
168+
}
169+
return limit;
170+
}
171+
172+
private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) {
173+
return agg.getGroupByExpressions().stream()
174+
.map(key -> new OrderKey(key, true, false))
175+
.collect(Collectors.toList());
176+
}
177+
}

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java

+1
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ public enum RuleType {
304304
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
305305
PUSH_LIMIT_THROUGH_WINDOW(RuleTypeClass.REWRITE),
306306
LIMIT_SORT_TO_TOP_N(RuleTypeClass.REWRITE),
307+
LIMIT_AGG_TO_TOPN_AGG(RuleTypeClass.REWRITE),
307308
// topN push down
308309
PUSH_DOWN_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
309310
PUSH_DOWN_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.rules.rewrite;
19+
20+
import org.apache.doris.nereids.properties.OrderKey;
21+
import org.apache.doris.nereids.rules.Rule;
22+
import org.apache.doris.nereids.rules.RuleType;
23+
import org.apache.doris.nereids.trees.expressions.Expression;
24+
import org.apache.doris.nereids.trees.plans.Plan;
25+
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
26+
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
27+
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
28+
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
29+
import org.apache.doris.qe.ConnectContext;
30+
31+
import com.google.common.collect.ImmutableList;
32+
import com.google.common.collect.Lists;
33+
34+
import java.util.List;
35+
import java.util.stream.Collectors;
36+
37+
/**
38+
* convert limit->agg to topn->agg
39+
* if all group keys are in limit.output
40+
* to enable
41+
* 1. topn-filter
42+
* 2. push limit to local agg
43+
*/
44+
public class LimitAggToTopNAgg implements RewriteRuleFactory {
45+
@Override
46+
public List<Rule> buildRules() {
47+
return ImmutableList.of(
48+
// limit -> agg to topn->agg
49+
logicalLimit(logicalAggregate())
50+
.when(limit -> ConnectContext.get() != null
51+
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
52+
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
53+
>= limit.getLimit() + limit.getOffset())
54+
.then(limit -> {
55+
LogicalAggregate<? extends Plan> agg = limit.child();
56+
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
57+
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg);
58+
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
59+
//limit->project->agg to topn->project->agg
60+
logicalLimit(logicalProject(logicalAggregate()))
61+
.when(limit -> ConnectContext.get() != null
62+
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
63+
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
64+
>= limit.getLimit() + limit.getOffset())
65+
.when(limit -> outputAllGroupKeys(limit, limit.child().child()))
66+
.then(limit -> {
67+
LogicalProject<? extends Plan> project = limit.child();
68+
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) project.child();
69+
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
70+
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), project);
71+
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
72+
// topn -> agg: add all group key to sort key, if sort key is prefix of group key
73+
logicalTopN(logicalAggregate())
74+
.when(topn -> ConnectContext.get() != null
75+
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
76+
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
77+
>= topn.getLimit() + topn.getOffset())
78+
.then(topn -> {
79+
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) topn.child();
80+
List<OrderKey> newOrders = tryGenerateOrderKeyByGroupKeyAndTopnKey(topn, agg);
81+
if (newOrders.isEmpty()) {
82+
return topn;
83+
} else {
84+
return topn.withOrderKeys(newOrders);
85+
}
86+
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG));
87+
}
88+
89+
private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(LogicalTopN<? extends Plan> topN,
90+
LogicalAggregate<? extends Plan> agg) {
91+
List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
92+
if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) {
93+
return orderKeys;
94+
}
95+
List<Expression> topnKeys = topN.getOrderKeys().stream()
96+
.map(OrderKey::getExpr).collect(Collectors.toList());
97+
for (int i = 0; i < topN.getOrderKeys().size(); i++) {
98+
// prefix check
99+
if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) {
100+
return Lists.newArrayList();
101+
}
102+
orderKeys.add(topN.getOrderKeys().get(i));
103+
}
104+
for (int i = topN.getOrderKeys().size(); i < agg.getGroupByExpressions().size(); i++) {
105+
orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false));
106+
}
107+
return orderKeys;
108+
}
109+
110+
private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) {
111+
return limit.getOutputSet().containsAll(agg.getGroupByExpressions());
112+
}
113+
114+
private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
115+
return agg.getGroupByExpressions().stream()
116+
.map(key -> new OrderKey(key, true, false))
117+
.collect(Collectors.toList());
118+
}
119+
}

0 commit comments

Comments
 (0)