Skip to content

Commit ae56111

Browse files
committed
rewrite limit->agg to topn-agg
1 parent 2949d6b commit ae56111

File tree

9 files changed

+180
-92
lines changed

9 files changed

+180
-92
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java

+2
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import org.apache.doris.nereids.rules.rewrite.InferPredicates;
8484
import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
8585
import org.apache.doris.nereids.rules.rewrite.InlineLogicalView;
86+
import org.apache.doris.nereids.rules.rewrite.LimitAggToTopNAgg;
8687
import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
8788
import org.apache.doris.nereids.rules.rewrite.MergeAggregate;
8889
import org.apache.doris.nereids.rules.rewrite.MergeFilters;
@@ -351,6 +352,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
351352
// generate one PhysicalLimit if current distribution is gather or two
352353
// PhysicalLimits with gather exchange
353354
topDown(new LimitSortToTopN()),
355+
topDown(new LimitAggToTopNAgg()),
354356
topDown(new MergeTopNs()),
355357
topDown(new SplitLimit()),
356358
topDown(

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ public List<PlanPostProcessor> getProcessors() {
6464
builder.add(new RecomputeLogicalPropertiesProcessor());
6565
builder.add(new AddOffsetIntoDistribute());
6666
builder.add(new CommonSubExpressionOpt());
67-
if (cascadesContext.getConnectContext().getSessionVariable().pushLimitToLocalAgg) {
68-
builder.add(new PushLimitToLocalAgg());
69-
}
7067
// DO NOT replace PLAN NODE from here
68+
if (cascadesContext.getConnectContext().getSessionVariable().pushTopnToAgg) {
69+
builder.add(new PushTopnToAgg());
70+
}
7171
builder.add(new TopNScanOpt());
7272
builder.add(new FragmentProcessor());
7373
if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushLimitToLocalAgg.java fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java

+61-47
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
package org.apache.doris.nereids.processor.post;
2222

2323
import org.apache.doris.nereids.CascadesContext;
24+
import org.apache.doris.nereids.properties.DistributionSpecGather;
2425
import org.apache.doris.nereids.properties.OrderKey;
2526
import org.apache.doris.nereids.trees.expressions.Expression;
27+
import org.apache.doris.nereids.trees.plans.AggMode;
2628
import org.apache.doris.nereids.trees.plans.Plan;
2729
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
2830
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
@@ -37,17 +39,20 @@
3739
import java.util.stream.Collectors;
3840

3941
/**
40-
Pattern1:
41-
limit(n) -> aggGlobal -> distribute -> aggLocal
42-
=>
43-
limit(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
44-
45-
Pattern2: topn orderkeys are the prefix of group keys
42+
TOPN:
43+
Precondition: topn orderkeys are the prefix of group keys
44+
Pattern 2-phase agg:
4645
topn -> aggGlobal -> distribute -> aggLocal
4746
=>
4847
topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
48+
Pattern 1-phase agg:
49+
topn -> agg -> any(not agg)  =>  topn -> agg(topn=n) -> any
50+
51+
LIMIT:
52+
Pattern 1: limit->agg(1phase)->any
53+
Pattern 2: limit->agg(global)->gather->agg(local)
4954
*/
50-
public class PushLimitToLocalAgg extends PlanPostProcessor {
55+
public class PushTopnToAgg extends PlanPostProcessor {
5156
@Override
5257
public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
5358
Plan topnChild = topN.child();
@@ -56,58 +61,27 @@ public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext
5661
}
5762
if (topnChild instanceof PhysicalHashAggregate) {
5863
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild;
59-
if (upperAgg.getAggPhase().isGlobal()
60-
&& upperAgg.child() instanceof PhysicalDistribute
61-
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
64+
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
6265
List<OrderKey> orderKeys = tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
6366
if (!orderKeys.isEmpty()) {
6467
upperAgg.setTopnPushInfo(new TopnPushInfo(
6568
orderKeys,
6669
topN.getLimit() + topN.getOffset()));
67-
PhysicalHashAggregate<? extends Plan> bottomAgg =
68-
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
69-
bottomAgg.setTopnPushInfo(new TopnPushInfo(
70-
orderKeys,
71-
topN.getLimit() + topN.getOffset()));
70+
if (upperAgg.child() instanceof PhysicalDistribute
71+
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
72+
PhysicalHashAggregate<? extends Plan> bottomAgg =
73+
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
74+
bottomAgg.setTopnPushInfo(new TopnPushInfo(
75+
orderKeys,
76+
topN.getLimit() + topN.getOffset()));
77+
}
7278
}
7379
}
7480
}
7581
topN.child().accept(this, ctx);
7682
return topN;
7783
}
7884

79-
@Override
80-
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
81-
Plan limitChild = limit.child();
82-
if (limitChild instanceof PhysicalProject) {
83-
limitChild = limitChild.child(0);
84-
}
85-
if (limitChild instanceof PhysicalHashAggregate) {
86-
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild;
87-
if (upperAgg.getAggPhase().isGlobal()
88-
&& upperAgg.child() instanceof PhysicalDistribute
89-
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
90-
upperAgg.setTopnPushInfo(new TopnPushInfo(
91-
generateOrderKeyByGroupKey(upperAgg),
92-
limit.getLimit() + limit.getOffset()));
93-
PhysicalHashAggregate<? extends Plan> bottomAgg =
94-
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
95-
bottomAgg.setTopnPushInfo(new TopnPushInfo(
96-
generateOrderKeyByGroupKey(bottomAgg),
97-
limit.getLimit() + limit.getOffset()));
98-
}
99-
}
100-
limit.child().accept(this, ctx);
101-
102-
return limit;
103-
}
104-
105-
private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) {
106-
return agg.getGroupByExpressions().stream()
107-
.map(key -> new OrderKey(key, true, false))
108-
.collect(Collectors.toList());
109-
}
110-
11185
/**
11286
return the order keys to push into the agg if topn order-key is prefix of agg group-key, ignore asc/desc and null_first; the caller treats an empty list as "do not push"
11387
TODO: order-key could be any subset of group-key, but BE does not support that yet.
@@ -132,4 +106,44 @@ private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? ex
132106
}
133107
return orderKeys;
134108
}
109+
110+
@Override
111+
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
112+
Plan limitChild = limit.child();
113+
if (limitChild instanceof PhysicalProject) {
114+
limitChild = limitChild.child(0);
115+
}
116+
if (limitChild instanceof PhysicalHashAggregate) {
117+
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild;
118+
if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) {
119+
Plan child = upperAgg.child();
120+
Plan grandChild = child.child(0);
121+
if (child instanceof PhysicalDistribute
122+
&& ((PhysicalDistribute<?>) child).getDistributionSpec() instanceof DistributionSpecGather
123+
&& grandChild instanceof PhysicalHashAggregate) {
124+
upperAgg.setTopnPushInfo(new TopnPushInfo(
125+
generateOrderKeyByGroupKey(upperAgg),
126+
limit.getLimit() + limit.getOffset()));
127+
PhysicalHashAggregate<? extends Plan> bottomAgg =
128+
(PhysicalHashAggregate<? extends Plan>) grandChild;
129+
bottomAgg.setTopnPushInfo(new TopnPushInfo(
130+
generateOrderKeyByGroupKey(bottomAgg),
131+
limit.getLimit() + limit.getOffset()));
132+
}
133+
} else if (upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
134+
// 1-phase agg
135+
upperAgg.setTopnPushInfo(new TopnPushInfo(
136+
generateOrderKeyByGroupKey(upperAgg),
137+
limit.getLimit() + limit.getOffset()));
138+
}
139+
}
140+
limit.child().accept(this, ctx);
141+
return limit;
142+
}
143+
144+
private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) {
145+
return agg.getGroupByExpressions().stream()
146+
.map(key -> new OrderKey(key, true, false))
147+
.collect(Collectors.toList());
148+
}
135149
}

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java

+1
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ public enum RuleType {
302302
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
303303
PUSH_LIMIT_THROUGH_WINDOW(RuleTypeClass.REWRITE),
304304
LIMIT_SORT_TO_TOP_N(RuleTypeClass.REWRITE),
305+
LIMIT_AGG_TO_TOPN_AGG(RuleTypeClass.REWRITE),
305306
// topN push down
306307
PUSH_DOWN_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
307308
PUSH_DOWN_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.rules.rewrite;
19+
20+
import org.apache.doris.nereids.properties.OrderKey;
21+
import org.apache.doris.nereids.rules.Rule;
22+
import org.apache.doris.nereids.rules.RuleType;
23+
import org.apache.doris.nereids.trees.plans.Plan;
24+
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
25+
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
26+
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
27+
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
28+
29+
import com.google.common.collect.ImmutableList;
30+
31+
import java.util.List;
32+
import java.util.stream.Collectors;
33+
34+
/**
35+
* convert limit->agg (or limit->project->agg) to topn->agg, using the agg's group-by keys as the order keys,
36+
* to enable
37+
* 1. topn-filter
38+
* 2. push limit to local agg
39+
*/
40+
public class LimitAggToTopNAgg implements RewriteRuleFactory {
41+
@Override
42+
public List<Rule> buildRules() {
43+
return ImmutableList.of(
44+
// limit -> agg to topn->agg
45+
logicalLimit(logicalAggregate())
46+
.then(limit -> {
47+
LogicalAggregate<? extends Plan> agg = limit.child();
48+
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
49+
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg);
50+
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
51+
// limit->project->agg to topn->project->agg
52+
logicalLimit(logicalProject(logicalAggregate()))
53+
.when(limit -> outputAllGroupKeys(limit, limit.child().child()))
54+
.then(limit -> {
55+
LogicalProject<? extends Plan> project = limit.child();
56+
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) project.child();
57+
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
58+
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), project);
59+
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG));
60+
}
61+
62+
private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) {
63+
return limit.getOutputSet().containsAll(agg.getGroupByExpressions());
64+
}
65+
66+
private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
67+
return agg.getGroupByExpressions().stream()
68+
.map(key -> new OrderKey(key, true, false))
69+
.collect(Collectors.toList());
70+
}
71+
}

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java

+16-15
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.apache.doris.nereids.trees.plans.PlanType;
3434
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
3535
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
36+
import org.apache.doris.nereids.util.MutableState;
3637
import org.apache.doris.nereids.util.Utils;
3738
import org.apache.doris.statistics.Statistics;
3839

@@ -61,9 +62,6 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar
6162

6263
private final RequireProperties requireProperties;
6364

64-
// only used in post processor
65-
private TopnPushInfo topnPushInfo = null;
66-
6765
public PhysicalHashAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
6866
AggregateParam aggregateParam, boolean maybeUsingStream, LogicalProperties logicalProperties,
6967
RequireProperties requireProperties, CHILD_TYPE child) {
@@ -192,6 +190,8 @@ public List<? extends Expression> getExpressions() {
192190

193191
@Override
194192
public String toString() {
193+
TopnPushInfo topnPushInfo = (TopnPushInfo) getMutableState(
194+
MutableState.KEY_PUSH_TOPN_TO_AGG).orElseGet(() -> null);
195195
return Utils.toSqlString("PhysicalHashAggregate[" + id.asInt() + "]" + getGroupIdWithPrefix(),
196196
"aggPhase", aggregateParam.aggPhase,
197197
"aggMode", aggregateParam.aggMode,
@@ -236,22 +236,19 @@ public PhysicalHashAggregate<Plan> withChildren(List<Plan> children) {
236236
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
237237
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
238238
requireProperties, physicalProperties, statistics,
239-
children.get(0))
240-
.setTopnPushInfo(topnPushInfo);
239+
children.get(0));
241240
}
242241

243242
public PhysicalHashAggregate<CHILD_TYPE> withPartitionExpressions(List<Expression> partitionExpressions) {
244243
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions,
245244
Optional.ofNullable(partitionExpressions), aggregateParam, maybeUsingStream,
246-
Optional.empty(), getLogicalProperties(), requireProperties, child())
247-
.setTopnPushInfo(topnPushInfo);
245+
Optional.empty(), getLogicalProperties(), requireProperties, child());
248246
}
249247

250248
@Override
251249
public PhysicalHashAggregate<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) {
252250
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
253-
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(), requireProperties, child())
254-
.setTopnPushInfo(topnPushInfo);
251+
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(), requireProperties, child());
255252
}
256253

257254
@Override
@@ -260,7 +257,7 @@ public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpr
260257
Preconditions.checkArgument(children.size() == 1);
261258
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
262259
aggregateParam, maybeUsingStream, groupExpression, logicalProperties.get(),
263-
requireProperties, children.get(0)).setTopnPushInfo(topnPushInfo);
260+
requireProperties, children.get(0));
264261
}
265262

266263
@Override
@@ -269,21 +266,21 @@ public PhysicalHashAggregate<CHILD_TYPE> withPhysicalPropertiesAndStats(Physical
269266
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
270267
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
271268
requireProperties, physicalProperties, statistics,
272-
child()).setTopnPushInfo(topnPushInfo);
269+
child());
273270
}
274271

275272
@Override
276273
public PhysicalHashAggregate<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) {
277274
return new PhysicalHashAggregate<>(groupByExpressions, newOutput, partitionExpressions,
278275
aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(),
279-
requireProperties, physicalProperties, statistics, child()).setTopnPushInfo(topnPushInfo);
276+
requireProperties, physicalProperties, statistics, child());
280277
}
281278

282279
public <C extends Plan> PhysicalHashAggregate<C> withRequirePropertiesAndChild(
283280
RequireProperties requireProperties, C newChild) {
284281
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
285282
aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(),
286-
requireProperties, physicalProperties, statistics, newChild).setTopnPushInfo(topnPushInfo);
283+
requireProperties, physicalProperties, statistics, newChild);
287284
}
288285

289286
@Override
@@ -322,11 +319,15 @@ public TopnPushInfo(List<OrderKey> orderkeys, long limit) {
322319
}
323320

324321
public TopnPushInfo getTopnPushInfo() {
325-
return topnPushInfo;
322+
Optional<Object> obj = getMutableState(MutableState.KEY_PUSH_TOPN_TO_AGG);
323+
if (obj.isPresent() && obj.get() instanceof TopnPushInfo) {
324+
return (TopnPushInfo) obj.get();
325+
}
326+
return null;
326327
}
327328

328329
public PhysicalHashAggregate<CHILD_TYPE> setTopnPushInfo(TopnPushInfo topnPushInfo) {
329-
this.topnPushInfo = topnPushInfo;
330+
setMutableState(MutableState.KEY_PUSH_TOPN_TO_AGG, topnPushInfo);
330331
return this;
331332
}
332333
}

fe/fe-core/src/main/java/org/apache/doris/nereids/util/MutableState.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public interface MutableState {
2727
String KEY_FRAGMENT = "fragment";
2828
String KEY_PARENT = "parent";
2929
String KEY_RF_JUMP = "rf-jump";
30-
30+
String KEY_PUSH_TOPN_TO_AGG = "pushTopnToAgg";
3131
<T> Optional<T> get(String key);
3232

3333
MutableState set(String key, Object value);

fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1195,8 +1195,8 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) {
11951195
@VariableMgr.VarAttr(name = REWRITE_OR_TO_IN_PREDICATE_THRESHOLD, fuzzy = true)
11961196
private int rewriteOrToInPredicateThreshold = 2;
11971197

1198-
@VariableMgr.VarAttr(name = "push_limit_to_local_agg", fuzzy = false, needForward = true)
1199-
public boolean pushLimitToLocalAgg = true;
1198+
@VariableMgr.VarAttr(name = "push_topn_to_agg", fuzzy = false, needForward = true)
1199+
public boolean pushTopnToAgg = true;
12001200

12011201
@VariableMgr.VarAttr(name = NEREIDS_CBO_PENALTY_FACTOR, needForward = true)
12021202
private double nereidsCboPenaltyFactor = 0.7;

0 commit comments

Comments
 (0)