From e03ea0d1c105ae1bf0dd0895082ceec7e60b4faf Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Mon, 30 Sep 2024 11:11:26 -0700 Subject: [PATCH 1/7] Support unparsing plans with both Aggregation and Window functions (#35) --- datafusion/sql/src/unparser/plan.rs | 25 +++-- datafusion/sql/src/unparser/utils.rs | 118 +++++++++++++++------- datafusion/sql/tests/cases/plan_to_sql.rs | 8 +- 3 files changed, 106 insertions(+), 45 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index a76e26aa7d98..c4fcbb2d6458 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -38,7 +38,10 @@ use super::{ rewrite_plan_for_sort_on_non_projected_fields, subquery_alias_inner_query_and_columns, TableAliasRewriter, }, - utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, + utils::{ + find_agg_node_within_select, find_window_nodes_within_select, + unproject_window_exprs, + }, Unparser, }; @@ -172,13 +175,17 @@ impl Unparser<'_> { p: &Projection, select: &mut SelectBuilder, ) -> Result<()> { - match find_agg_node_within_select(plan, None, true) { - Some(AggVariant::Aggregate(agg)) => { + match ( + find_agg_node_within_select(plan, true), + find_window_nodes_within_select(plan, None, true), + ) { + (Some(agg), window) => { + let window_option = window.as_deref(); let items = p .expr .iter() .map(|proj_expr| { - let unproj = unproject_agg_exprs(proj_expr, agg)?; + let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?; self.select_item_to_sql(&unproj) }) .collect::>>()?; @@ -192,7 +199,7 @@ impl Unparser<'_> { vec![], )); } - Some(AggVariant::Window(window)) => { + (None, Some(window)) => { let items = p .expr .iter() @@ -204,7 +211,7 @@ impl Unparser<'_> { select.projection(items); } - None => { + _ => { let items = p .expr .iter() @@ -287,10 +294,10 @@ impl Unparser<'_> { self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } LogicalPlan::Filter(filter) => { - if let Some(AggVariant::Aggregate(agg)) = - find_agg_node_within_select(plan, None, select.already_projected()) + if let Some(agg) = + find_agg_node_within_select(plan, select.already_projected()) { - let unprojected = unproject_agg_exprs(&filter.predicate, agg)?; + let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?; let filter_expr = self.expr_to_sql(&unprojected)?; select.having(Some(filter_expr)); } else { diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index c1b3fe18f7e7..399bb876b3d0 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -18,58 +18,81 @@ use datafusion_common::{ internal_err, tree_node::{Transformed, TreeNode}, - Result, + Column, Result, }; use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window}; -/// One of the possible aggregation plans which can be found within a single select query. -pub(crate) enum AggVariant<'a> { - Aggregate(&'a Aggregate), - Window(Vec<&'a Window>), +/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists +/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). +/// If an Aggregate or node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. +pub(crate) fn find_agg_node_within_select( + plan: &LogicalPlan, + already_projected: bool, +) -> Option<&Aggregate> { + // Note that none of the nodes that have a corresponding node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()? + }; + // Agg nodes explicitly return immediately with a single node + if let LogicalPlan::Aggregate(agg) = input { + Some(agg) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + if already_projected { + None + } else { + find_agg_node_within_select(input, true) + } + } else { + find_agg_node_within_select(input, already_projected) + } } -/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists +/// Recursively searches children of [LogicalPlan] to find Window nodes if exist /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). -/// If an Aggregate or window node is not found prior to this or at all before reaching the end -/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both -/// be found in a single select query. -pub(crate) fn find_agg_node_within_select<'a>( +/// If Window node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. +pub(crate) fn find_window_nodes_within_select<'a>( plan: &'a LogicalPlan, - mut prev_windows: Option>, + mut prev_windows: Option>, already_projected: bool, -) -> Option> { - // Note that none of the nodes that have a corresponding agg node can have more +) -> Option> { + // Note that none of the nodes that have a corresponding node can have more // than 1 input node. E.g. Projection / Filter always have 1 input node. let input = plan.inputs(); let input = if input.len() > 1 { - return None; + return prev_windows; } else { input.first()? }; - // Agg nodes explicitly return immediately with a single node // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection match input { - LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)), LogicalPlan::Window(window) => { prev_windows = match &mut prev_windows { - Some(AggVariant::Window(windows)) => { + Some(windows) => { windows.push(window); prev_windows } - _ => Some(AggVariant::Window(vec![window])), + _ => Some(vec![window]), }; - find_agg_node_within_select(input, prev_windows, already_projected) + find_window_nodes_within_select(input, prev_windows, already_projected) } LogicalPlan::Projection(_) => { if already_projected { prev_windows } else { - find_agg_node_within_select(input, prev_windows, true) + find_window_nodes_within_select(input, prev_windows, true) } } LogicalPlan::TableScan(_) => prev_windows, - _ => find_agg_node_within_select(input, prev_windows, already_projected), + _ => find_window_nodes_within_select(input, prev_windows, already_projected), } } @@ -78,19 +101,30 @@ pub(crate) fn find_agg_node_within_select<'a>( /// /// For example, if expr contains the column expr "COUNT(*)" it will be transformed /// into an actual aggregate expression COUNT(*) as identified in the aggregate node. -pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result { +pub(crate) fn unproject_agg_exprs( + expr: &Expr, + agg: &Aggregate, + windows: Option<&[&Window]>, +) -> Result { expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - // find the column in the agg schema - if let Ok(n) = agg.schema.index_of_column(&c) { - let unprojected_expr = agg - .group_expr - .iter() - .chain(agg.aggr_expr.iter()) - .nth(n) - .unwrap(); + if let Some(unprojected_expr) = find_agg_expr(agg, &c) { Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(mut unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + if let Expr::WindowFunction(func) = &mut unprojected_expr { + // Window function can contain aggregation column, for ex 'avg(sum(ss_sales_price)) over ..' that needs to be unprojected + for arg in &mut func.args { + if let Expr::Column(c) = arg { + if let Some(expr) = find_agg_expr(agg, c) { + *arg = expr.clone(); + } + } + } + } + Ok(Transformed::yes(unprojected_expr)) } else { internal_err!( "Tried to unproject agg expr not found in provided Aggregate!" @@ -112,11 +146,7 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - if let Some(unproj) = windows - .iter() - .flat_map(|w| w.window_expr.iter()) - .find(|window_expr| window_expr.schema_name().to_string() == c.name) - { + if let Some(unproj) = find_window_expr(windows, &c.name) { Ok(Transformed::yes(unproj.clone())) } else { Ok(Transformed::no(Expr::Column(c))) @@ -127,3 +157,21 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result }) .map(|e| e.data) } + +fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Option<&'a Expr> { + if let Ok(index) = agg.schema.index_of_column(column) { + agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index) + } else { + None + } +} + +fn find_window_expr<'a>( + windows: &'a [&'a Window], + column_name: &'a str, +) -> Option<&'a Expr> { + windows + .iter() + .flat_map(|w| w.window_expr.iter()) + .find(|expr| expr.schema_name().to_string() == column_name) +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 49f4720ed137..1725d471d067 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -149,7 +149,12 @@ fn roundtrip_statement() -> Result<()> { "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3", "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4", "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col", - ]; + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM person GROUP BY id, first_name"#, + ]; // For each test sql string, we transform as follows: // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2) @@ -164,6 +169,7 @@ fn roundtrip_statement() -> Result<()> { let state = MockSessionState::default() .with_aggregate_function(sum_udaf()) .with_aggregate_function(count_udaf()) + .with_aggregate_function(max_udaf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); From e4288c251be0084ae92691c41214f7e557929a5a Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Tue, 1 Oct 2024 11:12:22 -0700 Subject: [PATCH 2/7] Fix unparsing for aggregation grouping sets --- datafusion/sql/src/unparser/utils.rs | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 399bb876b3d0..2200b86fc5b9 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -20,7 +20,9 @@ use datafusion_common::{ tree_node::{Transformed, TreeNode}, Column, Result, }; -use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window}; +use datafusion_expr::{ + utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, +}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -109,16 +111,16 @@ pub(crate) fn unproject_agg_exprs( expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - if let Some(unprojected_expr) = find_agg_expr(agg, &c) { + if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) } else if let Some(mut unprojected_expr) = windows.and_then(|w| find_window_expr(w, &c.name).cloned()) { if let Expr::WindowFunction(func) = &mut unprojected_expr { - // Window function can contain aggregation column, for ex 'avg(sum(ss_sales_price)) over ..' that needs to be unprojected + // Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected for arg in &mut func.args { if let Expr::Column(c) = arg { - if let Some(expr) = find_agg_expr(agg, c) { + if let Some(expr) = find_agg_expr(agg, c)? { *arg = expr.clone(); } } @@ -158,11 +160,20 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result .map(|e| e.data) } -fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Option<&'a Expr> { +fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result> { if let Ok(index) = agg.schema.index_of_column(column) { - agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index) + if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) { + // For grouping set expr, we must operate by expression list from the grouping set + let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?; + return Ok(grouping_expr + .into_iter() + .chain(agg.aggr_expr.iter()) + .nth(index)); + } else { + return Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)); + }; } else { - None + Ok(None) } } From ca263ee9e583d250d4606c8fb93318b77420cf71 Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Tue, 1 Oct 2024 11:33:44 -0700 Subject: [PATCH 3/7] Add test for grouping set unparsing --- datafusion/sql/src/unparser/utils.rs | 2 +- datafusion/sql/tests/cases/plan_to_sql.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 2200b86fc5b9..85a77a087690 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -129,7 +129,7 @@ pub(crate) fn unproject_agg_exprs( Ok(Transformed::yes(unprojected_expr)) } else { internal_err!( - "Tried to unproject agg expr not found in provided Aggregate!" + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name ) } } else { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 1725d471d067..a44b3e561f33 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -154,7 +154,8 @@ fn roundtrip_statement() -> Result<()> { SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total FROM person GROUP BY id, first_name"#, - ]; + "SELECT id, first_name, last_name, SUM(id) AS total_sum FROM person GROUP BY ROLLUP(id, first_name, last_name)", + ]; // For each test sql string, we transform as follows: // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2) From ae9af9507ea7b5bdde0e459c581d6bff5f7a817b Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Wed, 2 Oct 2024 23:22:29 -0700 Subject: [PATCH 4/7] Update datafusion/sql/src/unparser/utils.rs Co-authored-by: Jax Liu --- datafusion/sql/src/unparser/utils.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 85a77a087690..25669f52c42d 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -165,13 +165,13 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result Date: Wed, 2 Oct 2024 23:23:09 -0700 Subject: [PATCH 5/7] Update datafusion/sql/src/unparser/utils.rs Co-authored-by: Jax Liu --- datafusion/sql/src/unparser/utils.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 25669f52c42d..6efa549d99de 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -118,13 +118,15 @@ pub(crate) fn unproject_agg_exprs( { if let Expr::WindowFunction(func) = &mut unprojected_expr { // Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected - for arg in &mut func.args { + func.args.iter_mut().try_for_each(|arg| { if let Expr::Column(c) = arg { if let Some(expr) = find_agg_expr(agg, c)? { *arg = expr.clone(); } } - } + Ok::<(), DataFusionError>(()) + })?; + } } Ok(Transformed::yes(unprojected_expr)) } else { From 2da2dbd50ba1e706eba959ba393fd7e47ad644c1 Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Wed, 2 Oct 2024 23:29:54 -0700 Subject: [PATCH 6/7] Update --- datafusion/sql/src/unparser/utils.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 6efa549d99de..0059aba25738 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -18,7 +18,7 @@ use datafusion_common::{ internal_err, tree_node::{Transformed, TreeNode}, - Column, Result, + Column, DataFusionError, Result, }; use datafusion_expr::{ utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, @@ -127,7 +127,6 @@ pub(crate) fn unproject_agg_exprs( Ok::<(), DataFusionError>(()) })?; } - } Ok(Transformed::yes(unprojected_expr)) } else { internal_err!( From 63c2ca5eefd7313f2fecb5098d8f6cfea0c27d4a Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Thu, 3 Oct 2024 00:18:07 -0700 Subject: [PATCH 7/7] More tests --- datafusion/sql/tests/cases/plan_to_sql.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index a44b3e561f33..903d4e28520b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -153,8 +153,22 @@ fn roundtrip_statement() -> Result<()> { SUM(id) AS total_sum, SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total - FROM person GROUP BY id, first_name"#, - "SELECT id, first_name, last_name, SUM(id) AS total_sum FROM person GROUP BY ROLLUP(id, first_name, last_name)", + FROM person JOIN orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM (SELECT id, first_name from person) person JOIN (SELECT customer_id FROM orders) orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum + FROM person + JOIN orders ON person.id = orders.customer_id + GROUP BY ROLLUP(id, first_name, last_name, customer_id)"#, + r#"SELECT id, first_name, last_name, + SUM(id) AS total_sum, + COUNT(*) AS total_count, + SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total + FROM person + GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#, ]; // For each test sql string, we transform as follows: