From 80c025bd4f0803004e8258ae80969800f88202c8 Mon Sep 17 00:00:00 2001 From: jackwener Date: Thu, 1 Dec 2022 12:34:11 +0800 Subject: [PATCH] enhance filter push through agg --- datafusion/optimizer/src/push_down_filter.rs | 48 ++++++-------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 779e6dd6ea52..966fa92e40b4 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -17,7 +17,6 @@ use crate::utils::conjunction; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; -use datafusion_expr::utils::exprlist_to_columns; use datafusion_expr::{ and, expr_rewriter::{replace_col, ExprRewritable, ExprRewriter}, @@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter { }) } LogicalPlan::Aggregate(agg) => { - // An aggregate's aggregate columns are _not_ filter-commutable => collect these: - // * columns whose aggregation expression depends on - // * the aggregation columns themselves - - // construct set of columns that `aggr_expr` depends on - let mut aggr_expr_columns = HashSet::new(); - exprlist_to_columns(&agg.aggr_expr, &mut aggr_expr_columns)?; - let agg_columns = agg - .aggr_expr + // We can push down predicates that only reference columns in `group_expr`.
+ let group_expr_columns = agg + .group_expr .iter() - .map(|x| Ok(Column::from_name(x.display_name()?))) + .map(|e| Ok(Column::from_qualified_name(&(e.display_name()?)))) .collect::<Result<HashSet<_>>>()?; - aggr_expr_columns.extend(agg_columns); let predicates = utils::split_conjunction_owned(utils::cnf_rewrite( filter.predicate().clone(), @@ -641,13 +633,11 @@ impl OptimizerRule for PushDownFilter { let mut keep_predicates = vec![]; let mut push_predicates = vec![]; for expr in predicates { - let columns = expr.to_columns()?; - if columns.is_empty() - || columns.intersection(&aggr_expr_columns).next().is_some() - { - keep_predicates.push(expr); - } else { + let cols = expr.to_columns()?; + if cols.iter().all(|c| group_expr_columns.contains(c)) { push_predicates.push(expr); + } else { + keep_predicates.push(expr); } } @@ -864,7 +854,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])? - .filter(col("a").gt(lit(10i64)))? + .filter(col("test.a").gt(lit(10i64)))?
.build()?; // filter of key aggregation is commutative let expected = "\ @@ -930,11 +920,9 @@ mod tests { // rewrite to CNF // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) - let expected = "\ - Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\ + let expected = "Filter: (test.c = Int64(1) OR test.c = Int64(1)) AND (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ - \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ - \n TableScan: test"; + \n TableScan: test"; assert_optimized_plan_eq(&plan, expected) } @@ -1890,17 +1878,9 @@ mod tests { #[async_trait] impl TableSource for PushDownProvider { fn schema(&self) -> SchemaRef { - Arc::new(arrow::datatypes::Schema::new(vec![ - arrow::datatypes::Field::new( - "a", - arrow::datatypes::DataType::Int32, - true, - ), - arrow::datatypes::Field::new( - "b", - arrow::datatypes::DataType::Int32, - true, - ), + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), ])) }