diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 6171d43b37f5..3d62bcf55d6c 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -54,6 +54,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; +use std::iter::zip; use std::sync::Arc; /// Default table name for unnamed table @@ -1196,39 +1197,36 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result>>()? - .to_dfschema()?; + })?; + + Ok(DFField::new( + left_field.qualifier().cloned(), + left_field.name(), + data_type, + nullable, + )) + }) + .collect::>>()? + .to_dfschema()?; let inputs = vec![left_plan, right_plan] .into_iter() - .flat_map(|p| match p { - LogicalPlan::Union(Union { inputs, .. }) => inputs, - other_plan => vec![Arc::new(other_plan)], - }) .map(|p| { let plan = coerce_plan_expr_for_schema(&p, &union_schema)?; match plan { @@ -1596,7 +1594,7 @@ mod tests { } #[test] - fn plan_builder_union_combined_single_union() -> Result<()> { + fn plan_builder_union() -> Result<()> { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?; @@ -1607,11 +1605,12 @@ mod tests { .union(plan.build()?)? .build()?; - // output has only one union let expected = "Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ + \n Union\ + \n Union\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ \n TableScan: employee_csv projection=[state, salary]"; assert_eq!(expected, format!("{plan:?}")); @@ -1620,7 +1619,7 @@ mod tests { } #[test] - fn plan_builder_union_distinct_combined_single_union() -> Result<()> { + fn plan_builder_union_distinct() -> Result<()> { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?; diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs new file mode 100644 index 000000000000..e22c73e5794d --- /dev/null +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to replace nested unions to single union. +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::{LogicalPlan, Union}; + +use crate::optimizer::ApplyOrder; +use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use std::sync::Arc; + +#[derive(Default)] +/// An optimization rule that replaces nested unions with a single union. +pub struct EliminateNestedUnion; + +impl EliminateNestedUnion { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateNestedUnion { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // TODO: Add optimization for nested distinct unions. + match plan { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .flat_map(|plan| match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => inputs + .iter() + .map(|plan| { + Arc::new( + coerce_plan_expr_for_schema(plan, schema).unwrap(), + ) + }) + .collect::>(), + _ => vec![plan.clone()], + }) + .collect::>(); + + Ok(Some(LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + }))) + } + _ => Ok(None), + } + } + + fn name(&self) -> &str { + "eliminate_nested_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{col, logical_plan::table_scan}; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]) + } + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) + } + + #[test] + fn eliminate_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + // We don't need to use project_with_column_index in logical optimizer, + // after LogicalPlanBuilder::union, we already have all equal expression aliases + #[test] + fn eliminate_nested_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union(table_2.build()?)? + .union(table_3.build()?)? + .build()?; + + let expected = "Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } +} diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs new file mode 100644 index 000000000000..70ee490346ff --- /dev/null +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to eliminate one union. +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::{LogicalPlan, Union}; + +use crate::optimizer::ApplyOrder; + +#[derive(Default)] +/// An optimization rule that eliminates union with one element. +pub struct EliminateOneUnion; + +impl EliminateOneUnion { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateOneUnion { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Union(Union { inputs, .. }) if inputs.len() == 1 => { + Ok(inputs.first().map(|input| input.as_ref().clone())) + } + _ => Ok(None), + } + } + + fn name(&self) -> &str { + "eliminate_one_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ToDFSchema; + use datafusion_expr::{ + expr_rewriter::coerce_plan_expr_for_schema, + logical_plan::{table_scan, Union}, + }; + use std::sync::Arc; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ]) + } + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq_with_rules( + vec![Arc::new(EliminateOneUnion::new())], + plan, + expected, + ) + } + + #[test] + fn eliminate_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_one_union() -> Result<()> { + let table_plan = coerce_plan_expr_for_schema( + &table_scan(Some("table"), &schema(), None)?.build()?, + &schema().to_dfschema()?, + )?; + let schema = table_plan.schema().clone(); + let single_union_plan = LogicalPlan::Union(Union { + inputs: vec![Arc::new(table_plan)], + schema, + }); + + let expected = "TableScan: table"; + assert_optimized_plan_equal(&single_union_plan, expected) + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 1d12ca7e3950..ede0ac5c7164 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -24,6 +24,8 @@ pub mod eliminate_duplicated_expr; pub mod eliminate_filter; pub mod eliminate_join; pub mod eliminate_limit; +pub mod eliminate_nested_union; +pub mod eliminate_one_union; pub mod eliminate_outer_join; pub mod eliminate_project; pub mod extract_equijoin_predicate; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index d3bdd47c5cb3..5231dc869875 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -24,6 +24,8 @@ use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; use crate::eliminate_filter::EliminateFilter; use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; +use crate::eliminate_nested_union::EliminateNestedUnion; +use crate::eliminate_one_union::EliminateOneUnion; use crate::eliminate_outer_join::EliminateOuterJoin; use crate::eliminate_project::EliminateProjection; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; @@ -220,6 +222,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(EliminateNestedUnion::new()), Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(ReplaceDistinctWithAggregate::new()), @@ -239,6 +242,8 @@ impl Optimizer { Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(PropagateEmptyRelation::new()), + // Must be after PropagateEmptyRelation + Arc::new(EliminateOneUnion::new()), Arc::new(FilterNullJoinKeys::default()), Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 4de7596b329c..040b69fc8bf3 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -182,12 +182,11 @@ fn empty_child(plan: &LogicalPlan) -> Result> { #[cfg(test)] mod tests { use crate::eliminate_filter::EliminateFilter; - use crate::optimizer::Optimizer; + use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ - assert_optimized_plan_eq, test_table_scan, test_table_scan_fields, - test_table_scan_with_name, + assert_optimized_plan_eq, assert_optimized_plan_eq_with_rules, test_table_scan, + test_table_scan_fields, test_table_scan_with_name, }; - use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, DFField, DFSchema, ScalarValue}; use datafusion_expr::logical_plan::table_scan; @@ -206,21 +205,15 @@ mod tests { plan: &LogicalPlan, expected: &str, ) -> Result<()> { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let optimizer = Optimizer::with_rules(vec![ - Arc::new(EliminateFilter::new()), - Arc::new(PropagateEmptyRelation::new()), - ]); - let config = &mut OptimizerContext::new() - .with_max_passes(1) - .with_skip_failing_rules(false); - let optimized_plan = optimizer - .optimize(plan, config, observe) - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); - Ok(()) + assert_optimized_plan_eq_with_rules( + vec![ + Arc::new(EliminateFilter::new()), + Arc::new(EliminateNestedUnion::new()), + Arc::new(PropagateEmptyRelation::new()), + ], + plan, + expected, + ) } #[test] diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 7d334a80b682..3eac2317b849 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -169,6 +169,25 @@ pub fn assert_optimized_plan_eq( Ok(()) } +pub fn assert_optimized_plan_eq_with_rules( + rules: Vec>, + plan: &LogicalPlan, + expected: &str, +) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + let config = &mut OptimizerContext::new() + .with_max_passes(1) + .with_skip_failing_rules(false); + let optimizer = Optimizer::with_rules(rules); + let optimized_plan = optimizer + .optimize(plan, config, observe) + .expect("failed to optimize plan"); + let formatted_plan = format!("{optimized_plan:?}"); + assert_eq!(formatted_plan, expected); + assert_eq!(plan.schema(), optimized_plan.schema()); + Ok(()) +} + pub fn assert_optimized_plan_eq_display_indent( rule: Arc, plan: &LogicalPlan, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index b1de5a12bcd0..661890e12533 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2061,24 +2061,6 @@ fn union_all() { quick_test(sql, expected); } -#[test] -fn union_4_combined_in_one() { - let sql = "SELECT order_id from orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders"; - let expected = "Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); -} - #[test] fn union_with_different_column_names() { let sql = "SELECT order_id from orders UNION ALL SELECT customer_id FROM orders"; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 23055cd978ff..0e190a6acd62 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -184,6 +184,7 @@ logical_plan after inline_table_scan SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -200,6 +201,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_limit SAME TEXT AS ABOVE logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE @@ -213,6 +215,7 @@ Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c --TableScan: simple_explain_test projection=[a, b, c] logical_plan after eliminate_projection TableScan: simple_explain_test projection=[a, b, c] logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -229,6 +232,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_limit SAME TEXT AS ABOVE logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE