Skip to content

Commit 25ba503

Browse files
committed
support limit->proj, support topn-agg, add rt
1 parent 7264db6 commit 25ba503

File tree

3 files changed

+83
-17
lines changed

3 files changed

+83
-17
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ public PlanFragment visitPhysicalDistribute(PhysicalDistribute<? extends Plan> d
289289
&& context.getFirstAggregateInFragment(inputFragment) == child) {
290290
PhysicalHashAggregate<?> hashAggregate = (PhysicalHashAggregate<?>) child;
291291
if (hashAggregate.getAggPhase() == AggPhase.LOCAL
292-
&& hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER) {
292+
&& hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER
293+
&& hashAggregate.getTopn() == null) {
293294
AggregationNode aggregationNode = (AggregationNode) inputFragment.getPlanRoot();
294295
aggregationNode.setUseStreamingPreagg(hashAggregate.isMaybeUsingStream());
295296
}

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushLimitToLocalAgg.java

+42-15
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,20 @@
2222

2323
import org.apache.doris.nereids.CascadesContext;
2424
import org.apache.doris.nereids.properties.OrderKey;
25+
import org.apache.doris.nereids.trees.expressions.Expression;
2526
import org.apache.doris.nereids.trees.plans.Plan;
2627
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
2728
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
2829
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopNOptInfo;
2930
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
31+
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
3032
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
3133

34+
import com.google.common.collect.Sets;
35+
36+
import java.util.HashSet;
3237
import java.util.List;
38+
import java.util.Set;
3339
import java.util.stream.Collectors;
3440

3541
/**
@@ -46,32 +52,54 @@
4652
public class PushLimitToLocalAgg extends PlanPostProcessor {
4753
@Override
4854
public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
49-
if (topN.child() instanceof PhysicalHashAggregate) {
50-
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topN.child();
55+
Plan topnChild = topN.child();
56+
if (topnChild instanceof PhysicalProject) {
57+
topnChild = topnChild.child(0);
58+
}
59+
if (topnChild instanceof PhysicalHashAggregate) {
60+
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild;
5161
upperAgg.setTopn(new TopNOptInfo(
5262
topN.getOrderKeys(),
5363
topN.getLimit() + topN.getOffset()));
5464
if (upperAgg.getAggPhase().isGlobal()) {
5565
if (upperAgg.child() instanceof PhysicalDistribute
5666
&& upperAgg.child().child(0) instanceof PhysicalHashAggregate) {
57-
PhysicalHashAggregate<? extends Plan> bottomAgg =
58-
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
59-
bottomAgg.setTopn(new TopNOptInfo(
60-
topN.getOrderKeys(),
61-
topN.getLimit() + topN.getOffset()));
62-
bottomAgg.child().accept(this, ctx);
67+
if (checkTopnKeyAndGroupKey(topN, upperAgg)) {
68+
PhysicalHashAggregate<? extends Plan> bottomAgg =
69+
(PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
70+
bottomAgg.setTopn(new TopNOptInfo(
71+
topN.getOrderKeys(),
72+
topN.getLimit() + topN.getOffset()));
73+
}
6374
}
6475
}
65-
} else {
66-
topN.child().accept(this, ctx);
6776
}
77+
topN.child().accept(this, ctx);
6878
return topN;
6979
}
7080

81+
/**
82+
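* Checks whether the topN order-key expressions and the aggregate's group-by expressions form the same set.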
83+
* @param topN
84+
* @param agg
85+
* @return true if the set of topN order-key expressions equals the set of agg group-by expressions; asc/desc and nulls-first settings are ignored
86+
*/
87+
private boolean checkTopnKeyAndGroupKey(PhysicalTopN<? extends Plan> topN,
88+
PhysicalHashAggregate<? extends Plan> agg) {
89+
Set<Expression> orderKeys = topN.getOrderKeys().stream()
90+
.map(OrderKey::getExpr).collect(Collectors.toSet());
91+
Set<Expression> groupKeys = new HashSet<>(agg.getGroupByExpressions());
92+
return groupKeys.size() == orderKeys.size() && groupKeys.containsAll(orderKeys);
93+
}
94+
7195
@Override
7296
public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
73-
if (limit.child() instanceof PhysicalHashAggregate) {
74-
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limit.child();
97+
Plan limitChild = limit.child();
98+
if (limitChild instanceof PhysicalProject) {
99+
limitChild = limitChild.child(0);
100+
}
101+
if (limitChild instanceof PhysicalHashAggregate) {
102+
PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild;
75103
upperAgg.setTopn(new TopNOptInfo(
76104
generateOrderKeysByGroupKeys(upperAgg),
77105
limit.getLimit() + limit.getOffset()));
@@ -83,12 +111,11 @@ public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesCont
83111
bottomAgg.setTopn(new TopNOptInfo(
84112
generateOrderKeysByGroupKeys(bottomAgg),
85113
limit.getLimit() + limit.getOffset()));
86-
bottomAgg.child().accept(this, ctx);
87114
}
88115
}
89-
} else {
90-
limit.child().accept(this, ctx);
91116
}
117+
limit.child().accept(this, ctx);
118+
92119
return limit;
93120
}
94121

regression-test/suites/nereids_tpch_p0/tpch/push_limit_to_local_agg.groovy

+39-1
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,46 @@
2020
suite("push_limit_to_local_agg") {
2121
String db = context.config.getDbNameByFile(new File(context.file.parent))
2222
sql "use ${db}"
23+
// limit -> agg
2324
explain{
24-
sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey order by o_custkey limit 4;"
25+
sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey limit 4;"
2526
multiContains ("sortByGroupKey:true", 2)
27+
notContains("STREAMING")
2628
}
29+
// after BE supports this push-down, change this to order_qt_limitAgg
30+
sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey limit 5;"
31+
32+
// limit -> proj -> agg
33+
explain{
34+
sql "select sum(c_custkey) from customer group by c_name limit 6;"
35+
multiContains ("sortByGroupKey:true", 2)
36+
notContains("STREAMING")
37+
}
38+
// after BE supports this push-down, change this to order_qt_limitProjAgg
39+
sql "select sum(c_custkey) from customer group by c_name limit 7;"
40+
41+
// topn -> agg
42+
explain{
43+
sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey order by o_custkey limit 8;"
44+
multiContains ("sortByGroupKey:true", 2)
45+
notContains("STREAMING")
46+
}
47+
48+
// topnKey != GroupKey: the order key (o_custkey) is only a subset of the group keys (o_custkey, o_clerk)
49+
explain{
50+
sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey, o_clerk order by o_custkey limit 9;"
51+
multiContains("sortByGroupKey:true", 1) // global agg
52+
multiContains("sortByGroupKey:false", 1) // local agg
53+
}
54+
55+
// topnKey != GroupKey: the order keys include the aggregate output x, which is not a group key
56+
explain{
57+
sql "select o_custkey, sum(o_shippriority) as x from orders group by o_custkey, o_clerk order by o_custkey, x limit 10;"
58+
multiContains("sortByGroupKey:true", 1) // global agg
59+
multiContains("sortByGroupKey:false", 1) // local agg
60+
}
61+
62+
63+
64+
2765
}

0 commit comments

Comments
 (0)