Skip to content

Commit

Permalink
enhance filter push through agg
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener committed Dec 1, 2022
1 parent 992f3af commit 80c025b
Showing 1 changed file with 14 additions and 34 deletions.
48 changes: 14 additions & 34 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
use crate::utils::conjunction;
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::utils::exprlist_to_columns;
use datafusion_expr::{
and,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
Expand Down Expand Up @@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter {
})
}
LogicalPlan::Aggregate(agg) => {
// An aggregate's aggregate columns are _not_ filter-commutable => collect these:
// * columns whose aggregation expression depends on
// * the aggregation columns themselves

// construct set of columns that `aggr_expr` depends on
let mut aggr_expr_columns = HashSet::new();
exprlist_to_columns(&agg.aggr_expr, &mut aggr_expr_columns)?;
let agg_columns = agg
.aggr_expr
// We can push down predicates whose columns all appear in the group-by expressions.
let group_expr_columns = agg
.group_expr
.iter()
.map(|x| Ok(Column::from_name(x.display_name()?)))
.map(|e| Ok(Column::from_qualified_name(&(e.display_name()?))))
.collect::<Result<HashSet<_>>>()?;
aggr_expr_columns.extend(agg_columns);

let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
filter.predicate().clone(),
Expand All @@ -641,13 +633,11 @@ impl OptimizerRule for PushDownFilter {
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let columns = expr.to_columns()?;
if columns.is_empty()
|| columns.intersection(&aggr_expr_columns).next().is_some()
{
keep_predicates.push(expr);
} else {
let cols = expr.to_columns()?;
if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}

Expand Down Expand Up @@ -864,7 +854,7 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
.filter(col("a").gt(lit(10i64)))?
.filter(col("test.a").gt(lit(10i64)))?
.build()?;
// filter of key aggregation is commutative
let expected = "\
Expand Down Expand Up @@ -930,11 +920,9 @@ mod tests {
// rewrite to CNF
// (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR c = 1) AND (b > 2 OR b > 3)

let expected = "\
Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
let expected = "Filter: (test.c = Int64(1) OR test.c = Int64(1)) AND (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\
\n TableScan: test";
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}

Expand Down Expand Up @@ -1890,17 +1878,9 @@ mod tests {
#[async_trait]
impl TableSource for PushDownProvider {
fn schema(&self) -> SchemaRef {
Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new(
"a",
arrow::datatypes::DataType::Int32,
true,
),
arrow::datatypes::Field::new(
"b",
arrow::datatypes::DataType::Int32,
true,
),
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]))
}

Expand Down

0 comments on commit 80c025b

Please sign in to comment.