Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt](nereids) optimize push limit to agg #44042

Merged
merged 5 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
Expand All @@ -32,6 +34,7 @@
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
Expand All @@ -53,7 +56,11 @@ public List<Rule> buildRules() {
>= limit.getLimit() + limit.getOffset())
.then(limit -> {
LogicalAggregate<? extends Plan> agg = limit.child();
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
if (!orderKeysOpt.isPresent()) {
return null;
}
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg);
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
//limit->project->agg to topn->project->agg
Expand All @@ -62,12 +69,47 @@ public List<Rule> buildRules() {
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= limit.getLimit() + limit.getOffset())
.when(limit -> outputAllGroupKeys(limit, limit.child().child()))
.then(limit -> {
LogicalProject<? extends Plan> project = limit.child();
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) project.child();
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), project);
LogicalAggregate<? extends Plan> agg
= (LogicalAggregate<? extends Plan>) project.child();
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
if (!orderKeysOpt.isPresent()) {
return null;
}
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
Plan result;

if (outputAllGroupKeys(limit, agg)) {
result = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), project);
} else {
// Add the first group-by key to the project below the TopN, then prune that key again
// with an upper project. The TopN order keys are a prefix of the group-by keys;
// refer to PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey().
Expression firstGroupByKey = agg.getGroupByExpressions().get(0);
if (!(firstGroupByKey instanceof SlotReference)) {
return null;
}
boolean shouldPruneFirstGroupByKey = true;
if (project.getOutputs().contains(firstGroupByKey)) {
shouldPruneFirstGroupByKey = false;
} else {
List<NamedExpression> bottomProjections = Lists.newArrayList(project.getProjects());
bottomProjections.add((SlotReference) firstGroupByKey);
project = project.withProjects(bottomProjections);
}
LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(),
limit.getOffset(), project);
if (shouldPruneFirstGroupByKey) {
List<NamedExpression> limitOutput = limit.getOutput().stream()
.map(e -> (NamedExpression) e).collect(Collectors.toList());
result = new LogicalProject<>(limitOutput, topn);
} else {
result = topn;
}
}
return result;
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
// topn -> agg: add all group key to sort key, if sort key is prefix of group key
logicalTopN(logicalAggregate())
Expand Down Expand Up @@ -111,9 +153,10 @@ private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) {
return limit.getOutputSet().containsAll(agg.getGroupByExpressions());
}

private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
return agg.getGroupByExpressions().stream()
.map(key -> new OrderKey(key, true, false))
.collect(Collectors.toList());
/**
 * Tries to build an order key from the first group-by expression of the aggregate.
 *
 * @param agg the aggregate whose group-by expressions are inspected
 * @return an {@link OrderKey} built from the first group-by expression with the flags
 *         {@code (true, false)}, or {@link Optional#empty()} when the aggregate has no
 *         group-by expression
 */
private Optional<OrderKey> tryGenerateOrderKeyByTheFirstGroupKey(LogicalAggregate<? extends Plan> agg) {
    boolean hasGroupKey = !agg.getGroupByExpressions().isEmpty();
    if (hasGroupKey) {
        // NOTE(review): the (true, false) flags presumably mean ascending / nulls not first;
        // confirm against the OrderKey constructor.
        OrderKey firstKeyOrder = new OrderKey(agg.getGroupByExpressions().get(0), true, false);
        return Optional.of(firstKeyOrder);
    }
    // Without any group-by key there is nothing to derive an order key from.
    return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ public String toString() {
"groupByExpr", groupByExpressions,
"outputExpr", outputExpressions,
"partitionExpr", partitionExpressions,
"requireProperties", requireProperties,
"topnOpt", topnPushInfo != null
"topnFilter", topnPushInfo != null,
"topnPushDown", getMutableState(MutableState.KEY_PUSH_TOPN_TO_AGG).isPresent()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ void testSortLimit() {
PlanChecker.from(connectContext).disableNereidsRules("PRUNE_EMPTY_PARTITION")
.analyze("select count(*) from (select * from student order by id) t limit 1")
.rewrite()
// there is no topn below agg
.matches(logicalTopN(logicalAggregate(logicalProject(logicalOlapScan()))));
.nonMatch(logicalTopN());
PlanChecker.from(connectContext)
.disableNereidsRules("PRUNE_EMPTY_PARTITION")
.analyze("select count(*) from (select * from student order by id limit 1) t")
Expand All @@ -184,8 +183,6 @@ void testSortLimit() {
.analyze("select count(*) from "
+ "(select * from student order by id) t1 left join student t2 on t1.id = t2.id limit 1")
.rewrite()
.matches(logicalTopN(logicalAggregate(logicalProject(logicalJoin(
logicalProject(logicalOlapScan()),
logicalProject(logicalOlapScan()))))));
.nonMatch(logicalTopN());
}
}
63 changes: 32 additions & 31 deletions regression-test/data/nereids_hint_tpcds_p0/shape/query23.out
Original file line number Diff line number Diff line change
Expand Up @@ -46,35 +46,36 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------filter(d_year IN (2000, 2001, 2002, 2003))
----------------------------------PhysicalOlapScan[date_dim]
----PhysicalResultSink
------PhysicalTopN[GATHER_SORT]
--------hashAgg[GLOBAL]
----------PhysicalDistribute[DistributionSpecGather]
------------hashAgg[LOCAL]
--------------PhysicalUnion
----------------PhysicalProject
------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF5 cs_item_sk->[item_sk]
--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
--------------------PhysicalProject
----------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF4 c_customer_sk->[cs_bill_customer_sk]
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3 RF4
----------------------------PhysicalProject
------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
--------------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
----------------PhysicalProject
------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF8 ws_item_sk->[item_sk]
--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8
--------------------PhysicalProject
----------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[ws_bill_customer_sk]
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF6 d_date_sk->[ws_sold_date_sk]
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[web_sales] apply RFs: RF6 RF7
----------------------------PhysicalProject
------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
--------------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
------PhysicalLimit[GLOBAL]
--------PhysicalLimit[LOCAL]
----------hashAgg[GLOBAL]
------------PhysicalDistribute[DistributionSpecGather]
--------------hashAgg[LOCAL]
----------------PhysicalUnion
------------------PhysicalProject
--------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF5 cs_item_sk->[item_sk]
----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
----------------------PhysicalProject
------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF4 c_customer_sk->[cs_bill_customer_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3 RF4
------------------------------PhysicalProject
--------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
----------------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
------------------PhysicalProject
--------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF8 ws_item_sk->[item_sk]
----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8
----------------------PhysicalProject
------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[ws_bill_customer_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF6 d_date_sk->[ws_sold_date_sk]
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[web_sales] apply RFs: RF6 RF7
------------------------------PhysicalProject
--------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
----------------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalCteConsumer ( cteId=CTEId#2 )

43 changes: 22 additions & 21 deletions regression-test/data/nereids_hint_tpcds_p0/shape/query32.out
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !ds_shape_32 --
PhysicalResultSink
--PhysicalTopN[GATHER_SORT]
----hashAgg[GLOBAL]
------PhysicalDistribute[DistributionSpecGather]
--------hashAgg[LOCAL]
----------PhysicalProject
------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cast(cs_ext_discount_amt as DECIMALV3(9, 4))) OVER(PARTITION BY i_item_sk))))
--------------PhysicalWindow
----------------PhysicalQuickSort[LOCAL_SORT]
------------------PhysicalDistribute[DistributionSpecHash]
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((date_dim.d_date_sk = catalog_sales.cs_sold_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_sold_date_sk]
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((item.i_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF0 i_item_sk->[cs_item_sk]
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1
----------------------------PhysicalProject
------------------------------filter((item.i_manufact_id = 722))
--------------------------------PhysicalOlapScan[item]
------------------------PhysicalProject
--------------------------filter((date_dim.d_date <= '2001-06-07') and (date_dim.d_date >= '2001-03-09'))
----------------------------PhysicalOlapScan[date_dim]
--PhysicalLimit[GLOBAL]
----PhysicalLimit[LOCAL]
------hashAgg[GLOBAL]
--------PhysicalDistribute[DistributionSpecGather]
----------hashAgg[LOCAL]
------------PhysicalProject
--------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cast(cs_ext_discount_amt as DECIMALV3(9, 4))) OVER(PARTITION BY i_item_sk))))
----------------PhysicalWindow
------------------PhysicalQuickSort[LOCAL_SORT]
--------------------PhysicalDistribute[DistributionSpecHash]
----------------------PhysicalProject
------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((date_dim.d_date_sk = catalog_sales.cs_sold_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_sold_date_sk]
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((item.i_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF0 i_item_sk->[cs_item_sk]
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1
------------------------------PhysicalProject
--------------------------------filter((item.i_manufact_id = 722))
----------------------------------PhysicalOlapScan[item]
--------------------------PhysicalProject
----------------------------filter((date_dim.d_date <= '2001-06-07') and (date_dim.d_date >= '2001-03-09'))
------------------------------PhysicalOlapScan[date_dim]

Hint log:
Used: leading(catalog_sales item date_dim )
Expand Down
Loading
Loading