-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix push_down_filter
for pushing filters on grouping columns rather than aggregate columns
#4447
Changes from all commits
c75e71c
018b9a5
992f3af
80c025b
e51daa2
c9c89c5
f9a7072
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,6 @@ | |
use crate::utils::conjunction; | ||
use crate::{utils, OptimizerConfig, OptimizerRule}; | ||
use datafusion_common::{Column, DFSchema, DataFusionError, Result}; | ||
use datafusion_expr::utils::exprlist_to_columns; | ||
use datafusion_expr::{ | ||
and, | ||
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter}, | ||
|
@@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter { | |
}) | ||
} | ||
LogicalPlan::Aggregate(agg) => { | ||
// An aggregate's aggregate columns are _not_ filter-commutable => collect these: | ||
// * columns whose aggregation expression depends on | ||
// * the aggregation columns themselves | ||
|
||
// construct set of columns that `aggr_expr` depends on | ||
let mut used_columns = HashSet::new(); | ||
exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?; | ||
let agg_columns = agg | ||
.aggr_expr | ||
// We can push down predicates that reference only columns in `group_expr`. | ||
let group_expr_columns = agg | ||
.group_expr | ||
.iter() | ||
.map(|x| Ok(Column::from_name(x.display_name()?))) | ||
.map(|e| Ok(Column::from_qualified_name(&(e.display_name()?)))) | ||
.collect::<Result<HashSet<_>>>()?; | ||
used_columns.extend(agg_columns); | ||
|
||
let predicates = utils::split_conjunction_owned(utils::cnf_rewrite( | ||
filter.predicate().clone(), | ||
|
@@ -641,20 +633,27 @@ impl OptimizerRule for PushDownFilter { | |
let mut keep_predicates = vec![]; | ||
let mut push_predicates = vec![]; | ||
for expr in predicates { | ||
let columns = expr.to_columns()?; | ||
if columns.is_empty() | ||
|| !columns | ||
.intersection(&used_columns) | ||
.collect::<HashSet<_>>() | ||
.is_empty() | ||
{ | ||
keep_predicates.push(expr); | ||
} else { | ||
let cols = expr.to_columns()?; | ||
if cols.iter().all(|c| group_expr_columns.contains(c)) { | ||
push_predicates.push(expr); | ||
} else { | ||
keep_predicates.push(expr); | ||
} | ||
} | ||
|
||
let child = match conjunction(push_predicates) { | ||
// As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice -- this is getting quite sophisticated. |
||
// After push, we need to replace `a+b` with Column(a)+Column(b) | ||
// So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))} | ||
let mut replace_map = HashMap::new(); | ||
for expr in &agg.group_expr { | ||
replace_map.insert(expr.display_name()?, expr.clone()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Double checked that |
||
} | ||
let replaced_push_predicates = push_predicates | ||
.iter() | ||
.map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
let child = match conjunction(replaced_push_predicates) { | ||
Some(predicate) => LogicalPlan::Filter(Filter::try_new( | ||
predicate, | ||
Arc::new((*agg.input).clone()), | ||
|
@@ -881,40 +880,30 @@ mod tests { | |
} | ||
|
||
#[test] | ||
fn filter_keep_agg() -> Result<()> { | ||
let table_scan = test_table_scan()?; | ||
let plan = LogicalPlanBuilder::from(table_scan) | ||
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? | ||
.filter(col("b").gt(lit(10i64)))? | ||
fn push_agg_need_replace_expr() -> Result<()> { | ||
let plan = LogicalPlanBuilder::from(test_table_scan()?) | ||
.aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])? | ||
.filter(col("test.b + test.a").gt(lit(10i64)))? | ||
.build()?; | ||
// filter of aggregate is after aggregation since they are non-commutative | ||
let expected = "\ | ||
Filter: b > Int64(10)\ | ||
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ | ||
\n TableScan: test"; | ||
let expected = | ||
"Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ | ||
\n Filter: test.b + test.a > Int64(10)\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 very nice |
||
\n TableScan: test"; | ||
assert_optimized_plan_eq(&plan, expected) | ||
} | ||
|
||
#[test] | ||
fn filter_keep_partial_agg() -> Result<()> { | ||
fn filter_keep_agg() -> Result<()> { | ||
let table_scan = test_table_scan()?; | ||
let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64))); | ||
let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64))); | ||
let filter = f1.or(f2); | ||
let plan = LogicalPlanBuilder::from(table_scan) | ||
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? | ||
.filter(filter)? | ||
.filter(col("b").gt(lit(10i64)))? | ||
.build()?; | ||
// filter of aggregate is after aggregation since they are non-commutative | ||
// (c =1 AND b > 2) OR (c = 1 AND b > 3) | ||
// rewrite to CNF | ||
// (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) | ||
|
||
let expected = "\ | ||
Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\ | ||
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ | ||
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\ | ||
\n TableScan: test"; | ||
Filter: b > Int64(10)\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree this new plan is correct |
||
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ | ||
\n TableScan: test"; | ||
assert_optimized_plan_eq(&plan, expected) | ||
} | ||
|
||
|
@@ -1870,17 +1859,9 @@ mod tests { | |
#[async_trait] | ||
impl TableSource for PushDownProvider { | ||
fn schema(&self) -> SchemaRef { | ||
Arc::new(arrow::datatypes::Schema::new(vec![ | ||
arrow::datatypes::Field::new( | ||
"a", | ||
arrow::datatypes::DataType::Int32, | ||
true, | ||
), | ||
arrow::datatypes::Field::new( | ||
"b", | ||
arrow::datatypes::DataType::Int32, | ||
true, | ||
), | ||
Arc::new(Schema::new(vec![ | ||
Field::new("a", DataType::Int32, true), | ||
Field::new("b", DataType::Int32, true), | ||
])) | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -304,6 +304,18 @@ fn join_keys_in_subquery_alias_1() { | |
assert_eq!(expected, format!("{:?}", plan)); | ||
} | ||
|
||
#[test] | ||
fn push_down_filter_groupby_expr_contains_alias() { | ||
let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3"; | ||
let plan = test_sql(sql).unwrap(); | ||
let expected = "Projection: c, COUNT(UInt8(1))\ | ||
\n Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\ | ||
\n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1))]]\ | ||
\n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
\n TableScan: test projection=[col_int32, col_uint32]"; | ||
assert_eq!(expected, format!("{:?}", plan)); | ||
} | ||
|
||
fn test_sql(sql: &str) -> Result<LogicalPlan> { | ||
// parse the SQL | ||
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original performance was bad
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I mentioned in the last PR, I think we do not need to check the aggregate Exprs, but just check the group-by Exprs. In some cases, the same column can exist in both the aggregate Exprs and the group-by Exprs, for example:
select count(distinct col_a), col_a from table group by col_a;
If there is a Filter applied to col_a, the Filter can still be pushed down even if it is referred to by the agg Exprs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic should check that all the columns used by the Filter predicate are a subset of the output columns of the group-by Exprs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. For pushing a filter down through an Agg, we can push an Expr that appears in groupby_expr.
I have added it.