From f9dddad83c14a5afe42a72c8883b28137adbc6c0 Mon Sep 17 00:00:00 2001 From: AssHero Date: Sun, 29 May 2022 18:37:32 +0800 Subject: [PATCH] if none of the columns in window expr are needed, remove the window exprs (#2634) * if none of the columns in window expr are needed, remove the window exprs * add test case for window expr elimination --- .../src/optimizer/projection_push_down.rs | 12 ++ datafusion/core/tests/sql/window.rs | 135 ++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs index 494b7dd64044..4feef4f99057 100644 --- a/datafusion/core/src/optimizer/projection_push_down.rs +++ b/datafusion/core/src/optimizer/projection_push_down.rs @@ -277,6 +277,18 @@ fn optimize_plan( })?; } + // none of the columns in window expr are needed, remove the window expr + if new_window_expr.is_empty() { + return LogicalPlanBuilder::from(optimize_plan( + _optimizer, + input, + required_columns, + true, + _execution_props, + )?) + .build(); + }; + // for all the retained window expr, find their sort expressions if any, and retain these exprlist_to_columns( &find_sort_exprs(&new_window_expr), diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 7e36177fde9a..bdbc77067ebd 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -298,3 +298,138 @@ async fn window_partition_by_order_by() -> Result<()> { assert_batches_eq!(expected, &results); Ok(()) } + +#[tokio::test] +async fn window_expr_eliminate() -> Result<()> { + let ctx = SessionContext::new(); + + // window expr is not referenced anywhere, eliminate it.
+ let sql = "WITH _sample_data AS ( + SELECT 1 as a, 'aa' AS b + UNION ALL + SELECT 3 as a, 'aa' AS b + UNION ALL + SELECT 5 as a, 'bb' AS b + UNION ALL + SELECT 7 as a, 'bb' AS b + ), _data2 AS ( + SELECT + row_number() OVER (PARTITION BY s.b ORDER BY s.a) AS seq, + s.a, + s.b + FROM _sample_data s + ) + SELECT d.b, MAX(d.a) AS max_a + FROM _data2 d + GROUP BY d.b + ORDER BY d.b;"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state.read().clone(); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Sort: #d.b ASC NULLS LAST [b:Utf8, max_a:Int64;N]", + " Projection: #d.b, #MAX(d.a) AS max_a [b:Utf8, max_a:Int64;N]", + " Aggregate: groupBy=[[#d.b]], aggr=[[MAX(#d.a)]] [b:Utf8, MAX(d.a):Int64;N]", + " Projection: #_data2.a, #_data2.b, alias=d [a:Int64, b:Utf8]", + " Projection: #s.a, #s.b, alias=_data2 [a:Int64, b:Utf8]", + " Projection: #a, #b, alias=s [a:Int64, b:Utf8]", + " Union [a:Int64, b:Utf8]", + " Projection: Int64(1) AS a, Utf8(\"aa\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + " Projection: Int64(3) AS a, Utf8(\"aa\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + " Projection: Int64(5) AS a, Utf8(\"bb\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + " Projection: Int64(7) AS a, Utf8(\"bb\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-------+", + "| b | max_a |", + "+----+-------+", + "| aa | 3 |", + "| bb | 7 |", + "+----+-------+", + ]; + + assert_batches_eq!(expected, &results); + + // window expr is referenced by the output, keep it + let sql = 
"WITH _sample_data AS ( + SELECT 1 as a, 'aa' AS b + UNION ALL + SELECT 3 as a, 'aa' AS b + UNION ALL + SELECT 5 as a, 'bb' AS b + UNION ALL + SELECT 7 as a, 'bb' AS b + ), _data2 AS ( + SELECT + row_number() OVER (PARTITION BY s.b ORDER BY s.a) AS seq, + s.a, + s.b + FROM _sample_data s + ) + SELECT d.b, MAX(d.a) AS max_a, max(d.seq) + FROM _data2 d + GROUP BY d.b + ORDER BY d.b;"; + + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Sort: #d.b ASC NULLS LAST [b:Utf8, max_a:Int64;N, MAX(d.seq):UInt64;N]", + " Projection: #d.b, #MAX(d.a) AS max_a, #MAX(d.seq) [b:Utf8, max_a:Int64;N, MAX(d.seq):UInt64;N]", + " Aggregate: groupBy=[[#d.b]], aggr=[[MAX(#d.a), MAX(#d.seq)]] [b:Utf8, MAX(d.a):Int64;N, MAX(d.seq):UInt64;N]", + " Projection: #_data2.seq, #_data2.a, #_data2.b, alias=d [seq:UInt64;N, a:Int64, b:Utf8]", + " Projection: #ROW_NUMBER() PARTITION BY [#s.b] ORDER BY [#s.a ASC NULLS LAST] AS seq, #s.a, #s.b, alias=_data2 [seq:UInt64;N, a:Int64, b:Utf8]", + " WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [#s.b] ORDER BY [#s.a ASC NULLS LAST]]] [ROW_NUMBER() PARTITION BY [#s.b] ORDER BY [#s.a ASC NULLS LAST]:UInt64;N, a:Int64, b:Utf8]", + " Projection: #a, #b, alias=s [a:Int64, b:Utf8]", + " Union [a:Int64, b:Utf8]", + " Projection: Int64(1) AS a, Utf8(\"aa\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + " Projection: Int64(3) AS a, Utf8(\"aa\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + " Projection: Int64(5) AS a, Utf8(\"bb\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + " Projection: Int64(7) AS a, Utf8(\"bb\") AS b [a:Int64, b:Utf8]", + " EmptyRelation []", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let 
results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-------+------------+", + "| b | max_a | MAX(d.seq) |", + "+----+-------+------------+", + "| aa | 3 | 2 |", + "| bb | 7 | 2 |", + "+----+-------+------------+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) +}