27
27
import org .apache .doris .nereids .trees .expressions .Slot ;
28
28
import org .apache .doris .nereids .trees .expressions .StatementScopeIdGenerator ;
29
29
import org .apache .doris .nereids .trees .expressions .functions .agg .Max ;
30
+ import org .apache .doris .nereids .trees .expressions .functions .agg .Sum ;
31
+ import org .apache .doris .nereids .trees .expressions .functions .scalar .If ;
30
32
import org .apache .doris .nereids .trees .expressions .literal .Literal ;
31
33
import org .apache .doris .nereids .trees .plans .logical .LogicalOlapScan ;
32
34
import org .apache .doris .nereids .trees .plans .logical .LogicalPlan ;
@@ -91,6 +93,28 @@ public void pushDownPredicateOneFilterTest() {
91
93
);
92
94
}
93
95
96
+ @ Test
97
+ void scalarAgg () {
98
+ LogicalPlan plan = new LogicalPlanBuilder (scan )
99
+ .agg (ImmutableList .of (), ImmutableList .of ((new Sum (scan .getOutput ().get (0 ))).alias ("sum" )))
100
+ .filter (new If (Literal .of (false ), Literal .of (false ), Literal .of (false )))
101
+ .project (ImmutableList .of (0 ))
102
+ .build ();
103
+
104
+ PlanChecker .from (MemoTestUtils .createConnectContext (), plan )
105
+ .applyTopDown (new PushDownFilterThroughAggregation ())
106
+ .printlnTree ()
107
+ .matches (
108
+ logicalProject (
109
+ logicalFilter (
110
+ logicalAggregate (
111
+ logicalOlapScan ()
112
+ )
113
+ )
114
+ )
115
+ );
116
+ }
117
+
94
118
/*-
95
119
* origin plan:
96
120
* project
@@ -174,7 +198,8 @@ public void pushDownPredicateGroupWithRepeatTest() {
174
198
logicalAggregate (
175
199
logicalFilter (
176
200
logicalRepeat ()
177
- ).when (filter -> filter .getConjuncts ().equals (ImmutableSet .of (filterPredicateId )))
201
+ ).when (filter -> filter .getConjuncts ()
202
+ .equals (ImmutableSet .of (filterPredicateId )))
178
203
)
179
204
)
180
205
);
@@ -195,9 +220,9 @@ public void pushDownPredicateGroupWithRepeatTest() {
195
220
.matches (
196
221
logicalProject (
197
222
logicalFilter (
198
- logicalAggregate (
199
- logicalRepeat ()
200
- )
223
+ logicalAggregate (
224
+ logicalRepeat ()
225
+ )
201
226
).when (filter -> filter .getConjuncts ().equals (ImmutableSet .of (filterPredicateId )))
202
227
)
203
228
);
0 commit comments