diff --git a/datafusion/optimizer/src/eliminate_distinct_nested_union.rs b/datafusion/optimizer/src/eliminate_distinct_nested_union.rs new file mode 100644 index 000000000000..85018c36ce78 --- /dev/null +++ b/datafusion/optimizer/src/eliminate_distinct_nested_union.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to replace nested unions to single union. +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::{ + logical_plan::{LogicalPlan, Union}, + Distinct, +}; + +use crate::optimizer::ApplyOrder; +use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use std::sync::Arc; + +#[derive(Default)] +/// An optimization rule that replaces nested distinct unions with a single union. +pub struct EliminateDistinctNestedUnion; + +impl EliminateDistinctNestedUnion { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateDistinctNestedUnion { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // TODO: Add optimization for nested distinct unions. + match plan { + LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .flat_map(|plan| match plan.as_ref() { + LogicalPlan::Distinct(Distinct { input }) => { + match input.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => { + inputs + .iter() + .map(|plan| { + Arc::new( + coerce_plan_expr_for_schema( + plan, schema, + ) + .unwrap(), + ) + }) + .collect::>() + } + _ => vec![plan.clone()], + } + } + LogicalPlan::Union(Union { inputs, schema }) => inputs + .iter() + .map(|plan| { + Arc::new( + coerce_plan_expr_for_schema(plan, schema) + .unwrap(), + ) + }) + .collect::>(), + _ => vec![plan.clone()], + }) + .collect::>(); + + Ok(Some(LogicalPlan::Distinct(Distinct { + input: Arc::new(LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + })), + }))) + } + _ => Ok(None), + }, + _ => Ok(None), + } + } + + fn name(&self) -> &str { + "eliminate_nested_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{col, logical_plan::table_scan}; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]) + } + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq( + Arc::new(EliminateDistinctNestedUnion::new()), + plan, + expected, + ) + } + + #[test] + fn eliminate_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + // We don't need to use project_with_column_index in logical optimizer, + // after LogicalPlanBuilder::union, we already have all equal expression aliases + #[test] + fn eliminate_nested_distinct_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union_distinct(table_2.build()?)? + .union_distinct(table_3.build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index ede0ac5c7164..828c55c46b64 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -20,6 +20,7 @@ pub mod common_subexpr_eliminate; pub mod decorrelate; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; +pub mod eliminate_distinct_nested_union; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; pub mod eliminate_join;