
fix push_down_filter for pushing filters on grouping columns rather than aggregate columns #4447

Merged
merged 7 commits on Dec 3, 2022
93 changes: 37 additions & 56 deletions datafusion/optimizer/src/push_down_filter.rs
@@ -17,7 +17,6 @@
use crate::utils::conjunction;
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::utils::exprlist_to_columns;
use datafusion_expr::{
and,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
@@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter {
})
}
LogicalPlan::Aggregate(agg) => {
// An aggregate's aggregate columns are _not_ filter-commutable => collect these:
// * columns whose aggregation expression depends on
// * the aggregation columns themselves

// construct set of columns that `aggr_expr` depends on
let mut used_columns = HashSet::new();
exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
let agg_columns = agg
.aggr_expr
// We can push down predicates that only reference columns in groupby_expr.
let group_expr_columns = agg
.group_expr
.iter()
.map(|x| Ok(Column::from_name(x.display_name()?)))
.map(|e| Ok(Column::from_qualified_name(&(e.display_name()?))))
.collect::<Result<HashSet<_>>>()?;
used_columns.extend(agg_columns);

let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
filter.predicate().clone(),
Expand All @@ -641,20 +633,27 @@ impl OptimizerRule for PushDownFilter {
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let columns = expr.to_columns()?;
if columns.is_empty()
|| !columns
.intersection(&used_columns)
.collect::<HashSet<_>>()
.is_empty()
Comment on lines -646 to -649
Member Author: The performance of the original check here was poor.
Contributor: As I mentioned in the last PR, I think we do not need to check the aggregate exprs, just the group-by exprs. In some cases the same column can exist in both the aggregate exprs and the group-by exprs, for example select count(distinct col_a), col_a from table group by col_a;. If a Filter is applied to col_a, it can still be pushed down even if col_a is referenced by the agg exprs.
Contributor: The logic should check that all the columns used by the Filter predicate are a subset of the output columns of the group-by exprs.
Member Author: Yes. For push_down_filter through Agg, we can push exprs that appear in group_expr. Added it.

{
keep_predicates.push(expr);
} else {
let cols = expr.to_columns()?;
if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}
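The decision in the loop above — push a conjunct below the aggregate only when every column it references is a grouping column — can be sketched with plain HashSets. This is a std-only model for illustration; the function and column names are hypothetical, not DataFusion API:

```rust
use std::collections::HashSet;

/// Hypothetical model of the pushdown decision: a predicate may move below
/// the aggregate only if every column it references is a group-by column.
fn can_push_down(predicate_cols: &HashSet<String>, group_by_cols: &HashSet<String>) -> bool {
    predicate_cols.iter().all(|c| group_by_cols.contains(c))
}

fn main() {
    let group_by: HashSet<String> = ["col_a".to_string()].into_iter().collect();

    // Filter on the grouping column col_a: safe to push down, even though
    // col_a may also appear in an aggregate expr like count(distinct col_a).
    let on_group: HashSet<String> = ["col_a".to_string()].into_iter().collect();
    assert!(can_push_down(&on_group, &group_by));

    // Filter on a non-grouping column such as an aggregate output: must stay above.
    let on_agg: HashSet<String> = ["col_b".to_string()].into_iter().collect();
    assert!(!can_push_down(&on_agg, &group_by));
}
```

Note this check is per conjunct: after the CNF split, each conjunct is tested independently, so a predicate can be partially pushed.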

let child = match conjunction(push_predicates) {
// As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
Contributor: Nice -- this is getting quite sophisticated.
// After pushdown, we need to replace `a+b` with Column(a)+Column(b),
// so we create a replace_map with the entry {`a+b` --> Expr(Column(a)+Column(b))}
let mut replace_map = HashMap::new();
for expr in &agg.group_expr {
replace_map.insert(expr.display_name()?, expr.clone());
Contributor: Double checked that display_name is the right one: https://docs.rs/datafusion/14.0.0/datafusion/prelude/enum.Expr.html#method.display_name 👍
}
let replaced_push_predicates = push_predicates
.iter()
.map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
.collect::<Result<Vec<_>>>()?;
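The replace-map idea above can be modeled without DataFusion at all: group-by expressions are keyed by their display name (e.g. "a + b"), so a predicate column whose name matches can be swapped back to the underlying expression before the filter moves below the aggregate. A std-only sketch, with strings standing in for real Expr values:

```rust
use std::collections::HashMap;

/// Build a map from a group-by expression's display name to the expression
/// itself. Here (display_name, underlying_expr) string pairs stand in for
/// real Expr values; this mirrors the replace_map construction, not its API.
fn build_replace_map(group_exprs: &[(&str, &str)]) -> HashMap<String, String> {
    group_exprs
        .iter()
        .map(|(display_name, expr)| (display_name.to_string(), expr.to_string()))
        .collect()
}

/// Rewrite a single predicate column: if its name matches a group-by
/// display name, substitute the underlying expression; otherwise keep it.
fn rewrite_col(col_name: &str, map: &HashMap<String, String>) -> String {
    map.get(col_name).cloned().unwrap_or_else(|| col_name.to_string())
}

fn main() {
    let map = build_replace_map(&[("a + b", "Column(a) + Column(b)")]);
    assert_eq!(rewrite_col("a + b", &map), "Column(a) + Column(b)");
    // Columns with no matching group-by display name pass through unchanged.
    assert_eq!(rewrite_col("c", &map), "c");
}
```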

let child = match conjunction(replaced_push_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new((*agg.input).clone()),
@@ -881,40 +880,30 @@ mod tests {
}

#[test]
fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(col("b").gt(lit(10i64)))?
fn push_agg_need_replace_expr() -> Result<()> {
let plan = LogicalPlanBuilder::from(test_table_scan()?)
.aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
.filter(col("test.b + test.a").gt(lit(10i64)))?
.build()?;
// filter of aggregate is after aggregation since they are non-commutative
let expected = "\
Filter: b > Int64(10)\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n TableScan: test";
let expected =
"Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
\n Filter: test.b + test.a > Int64(10)\
Contributor: 👍 very nice
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}

#[test]
fn filter_keep_partial_agg() -> Result<()> {
fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
let filter = f1.or(f2);
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(filter)?
.filter(col("b").gt(lit(10i64)))?
.build()?;
// filter of aggregate is after aggregation since they are non-commutative
// (c =1 AND b > 2) OR (c = 1 AND b > 3)
// rewrite to CNF
// (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)

let expected = "\
Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\
\n TableScan: test";
Filter: b > Int64(10)\
Contributor: I agree this new plan is correct
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
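The cnf_rewrite step used earlier in this diff (before splitting conjunctions) is what turns a predicate like (c = 1 AND b > 2) OR (c = 1 AND b > 3) into a conjunction whose parts can be considered for pushdown independently. A std-only sketch of the OR-over-AND distribution; the Expr enum below is a hypothetical stand-in, not DataFusion's:

```rust
/// Minimal expression tree for sketching a CNF rewrite.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Leaf(&'static str),
    And(Box<Expr>, Box<Expr>),
    Or(Box<Expr>, Box<Expr>),
}

/// Distribute OR over AND: (a AND b) OR c  =>  (a OR c) AND (b OR c).
fn to_cnf(e: Expr) -> Expr {
    match e {
        Expr::Or(l, r) => {
            let l = to_cnf(*l);
            let r = to_cnf(*r);
            match (l, r) {
                (Expr::And(a, b), c) => Expr::And(
                    Box::new(to_cnf(Expr::Or(a, Box::new(c.clone())))),
                    Box::new(to_cnf(Expr::Or(b, Box::new(c)))),
                ),
                (c, Expr::And(a, b)) => Expr::And(
                    Box::new(to_cnf(Expr::Or(Box::new(c.clone()), a))),
                    Box::new(to_cnf(Expr::Or(Box::new(c), b))),
                ),
                (l, r) => Expr::Or(Box::new(l), Box::new(r)),
            }
        }
        Expr::And(l, r) => Expr::And(Box::new(to_cnf(*l)), Box::new(to_cnf(*r))),
        leaf => leaf,
    }
}

/// Count top-level conjuncts of a CNF expression.
fn count_conjuncts(e: &Expr) -> usize {
    match e {
        Expr::And(l, r) => count_conjuncts(l) + count_conjuncts(r),
        _ => 1,
    }
}

fn main() {
    use Expr::*;
    // (c = 1 AND b > 2) OR (c = 1 AND b > 3)
    let pred = Or(
        Box::new(And(Box::new(Leaf("c = 1")), Box::new(Leaf("b > 2")))),
        Box::new(And(Box::new(Leaf("c = 1")), Box::new(Leaf("b > 3")))),
    );
    // CNF: (c=1 OR c=1) AND (c=1 OR b>3) AND (b>2 OR c=1) AND (b>2 OR b>3);
    // each of the 4 conjuncts is then tested for pushdown independently.
    assert_eq!(count_conjuncts(&to_cnf(pred)), 4);
}
```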

@@ -1870,17 +1859,9 @@ mod tests {
#[async_trait]
impl TableSource for PushDownProvider {
fn schema(&self) -> SchemaRef {
Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new(
"a",
arrow::datatypes::DataType::Int32,
true,
),
arrow::datatypes::Field::new(
"b",
arrow::datatypes::DataType::Int32,
true,
),
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]))
}

12 changes: 12 additions & 0 deletions datafusion/optimizer/tests/integration-test.rs
@@ -304,6 +304,18 @@ fn join_keys_in_subquery_alias_1() {
assert_eq!(expected, format!("{:?}", plan));
}

#[test]
fn push_down_filter_groupby_expr_contains_alias() {
let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
let plan = test_sql(sql).unwrap();
let expected = "Projection: c, COUNT(UInt8(1))\
\n Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\
\n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1))]]\
\n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
Contributor: 👍
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{:?}", plan));
}

fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...