Skip to content

Commit

Permalink
Move nested union optimization from plan builder to logical optimizer (
Browse files Browse the repository at this point in the history
…#7695)

* Add naive implementation of eliminate_nested_union

* Remove union optimization from LogicalPlanBuilder::union

* Fix propagate_union_children_different_schema test

* Add implementation of eliminate_one_union

* Simplified eliminate_nested_union test

* Fix

* clippy

---------

Co-authored-by: Evgeny Maruschenko <evgeny.maruschenko@x5.ru>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
3 people committed Oct 10, 2023
1 parent a0c5aff commit 704e034
Show file tree
Hide file tree
Showing 9 changed files with 404 additions and 71 deletions.
67 changes: 33 additions & 34 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ use std::any::Any;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::iter::zip;
use std::sync::Arc;

/// Default table name for unnamed table
Expand Down Expand Up @@ -1196,39 +1197,36 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalP
}

// create union schema
let union_schema = (0..left_col_num)
.map(|i| {
let left_field = left_plan.schema().field(i);
let right_field = right_plan.schema().field(i);
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type =
comparison_coercion(left_field.data_type(), right_field.data_type())
.ok_or_else(|| {
DataFusionError::Plan(format!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
left_field.data_type()
))
})?;

Ok(DFField::new(
left_field.qualifier().cloned(),
let union_schema = zip(
left_plan.schema().fields().iter(),
right_plan.schema().fields().iter(),
)
.map(|(left_field, right_field)| {
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type =
comparison_coercion(left_field.data_type(), right_field.data_type())
.ok_or_else(|| {
DataFusionError::Plan(format!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
data_type,
nullable,
left_field.data_type()
))
})
.collect::<Result<Vec<_>>>()?
.to_dfschema()?;
})?;

Ok(DFField::new(
left_field.qualifier().cloned(),
left_field.name(),
data_type,
nullable,
))
})
.collect::<Result<Vec<_>>>()?
.to_dfschema()?;

let inputs = vec![left_plan, right_plan]
.into_iter()
.flat_map(|p| match p {
LogicalPlan::Union(Union { inputs, .. }) => inputs,
other_plan => vec![Arc::new(other_plan)],
})
.map(|p| {
let plan = coerce_plan_expr_for_schema(&p, &union_schema)?;
match plan {
Expand Down Expand Up @@ -1596,7 +1594,7 @@ mod tests {
}

#[test]
fn plan_builder_union_combined_single_union() -> Result<()> {
fn plan_builder_union() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?;

Expand All @@ -1607,11 +1605,12 @@ mod tests {
.union(plan.build()?)?
.build()?;

// output has only one union
let expected = "Union\
\n TableScan: employee_csv projection=[state, salary]\
\n TableScan: employee_csv projection=[state, salary]\
\n TableScan: employee_csv projection=[state, salary]\
\n Union\
\n Union\
\n TableScan: employee_csv projection=[state, salary]\
\n TableScan: employee_csv projection=[state, salary]\
\n TableScan: employee_csv projection=[state, salary]\
\n TableScan: employee_csv projection=[state, salary]";

assert_eq!(expected, format!("{plan:?}"));
Expand All @@ -1620,7 +1619,7 @@ mod tests {
}

#[test]
fn plan_builder_union_distinct_combined_single_union() -> Result<()> {
fn plan_builder_union_distinct() -> Result<()> {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?;

Expand Down
211 changes: 211 additions & 0 deletions datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Optimizer rule that replaces nested unions with a single union.
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::logical_plan::{LogicalPlan, Union};

use crate::optimizer::ApplyOrder;
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use std::sync::Arc;

/// Optimizer rule that collapses a `Union` whose inputs are themselves
/// `Union` nodes into one `Union` over the nested inputs.
///
/// The rule is a stateless unit struct.
#[derive(Default)]
pub struct EliminateNestedUnion;

impl EliminateNestedUnion {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}

impl OptimizerRule for EliminateNestedUnion {
    /// Flattens one level of `Union` nesting directly beneath `plan`.
    ///
    /// When `plan` is a `Union`, every input that is itself a `Union` is
    /// replaced by that nested union's own inputs, each passed through
    /// `coerce_plan_expr_for_schema` with the nested union's schema. Any other
    /// input is kept as is (only its `Arc` is cloned). The rebuilt `Union`
    /// keeps the outer union's schema.
    ///
    /// Returns `Ok(None)` when `plan` is not a `Union`. A `Union` always
    /// produces `Ok(Some(..))`, even when none of its inputs was nested.
    ///
    /// # Errors
    ///
    /// Returns the error from `coerce_plan_expr_for_schema` if a nested input
    /// cannot be coerced, instead of panicking inside the optimizer.
    fn try_optimize(
        &self,
        plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        // TODO: Add optimization for nested distinct unions.
        match plan {
            LogicalPlan::Union(Union { inputs, schema }) => {
                // Lower bound only: each nested union contributes all of its
                // own inputs, so the vector may still grow.
                let mut flattened = Vec::with_capacity(inputs.len());

                for input in inputs {
                    match input.as_ref() {
                        LogicalPlan::Union(Union {
                            inputs: nested_inputs,
                            schema: nested_schema,
                        }) => {
                            for nested in nested_inputs {
                                // `?` surfaces a coercion failure to the caller
                                // rather than aborting the whole process.
                                let coerced =
                                    coerce_plan_expr_for_schema(nested, nested_schema)?;
                                flattened.push(Arc::new(coerced));
                            }
                        }
                        _ => flattened.push(Arc::clone(input)),
                    }
                }

                Ok(Some(LogicalPlan::Union(Union {
                    inputs: flattened,
                    schema: schema.clone(),
                })))
            }
            _ => Ok(None),
        }
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "eliminate_nested_union"
    }

    /// Requests bottom-up application of this rule.
    // NOTE(review): presumably this means inner unions are already flattened
    // when an outer union is visited, so one level per call is enough --
    // confirm against the optimizer driver.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{col, logical_plan::table_scan};

    /// Schema used by most tests: non-nullable `id`, `key` and `value` columns.
    fn schema() -> Schema {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ];
        Schema::new(fields)
    }

    /// Runs `EliminateNestedUnion` over `plan` and compares the result with
    /// `expected` via the shared `assert_optimized_plan_eq` test helper.
    fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
        let rule = Arc::new(EliminateNestedUnion::new());
        assert_optimized_plan_eq(rule, plan, expected)
    }

    /// A single, non-nested union must come out unchanged.
    #[test]
    fn eliminate_nothing() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;

        let right = scan.clone().build()?;
        let plan = scan.union(right)?.build()?;

        let expected = "\
        Union\
        \n TableScan: table\
        \n TableScan: table";
        assert_optimized_plan_equal(&plan, expected)
    }

    /// Three chained unions over the same scan collapse into one union with
    /// four inputs.
    #[test]
    fn eliminate_nested_union() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;

        // Each iteration wraps the plan built so far in another union.
        let mut builder = scan.clone();
        for _ in 0..3 {
            builder = builder.union(scan.clone().build()?)?;
        }
        let plan = builder.build()?;

        let expected = "\
        Union\
        \n TableScan: table\
        \n TableScan: table\
        \n TableScan: table\
        \n TableScan: table";
        assert_optimized_plan_equal(&plan, expected)
    }

    // project_with_column_index is not needed in the logical optimizer:
    // after LogicalPlanBuilder::union all inputs already carry equal
    // expression aliases.
    #[test]
    fn eliminate_nested_union_with_projection() -> Result<()> {
        let scan = table_scan(Some("table"), &schema(), None)?;

        let aliased_table_id = scan
            .clone()
            .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
            .build()?;
        let aliased_underscore_id = scan
            .clone()
            .project(vec![col("id").alias("_id"), col("key"), col("value")])?
            .build()?;

        let plan = scan
            .union(aliased_table_id)?
            .union(aliased_underscore_id)?
            .build()?;

        let expected = "Union\
        \n TableScan: table\
        \n Projection: table.id AS id, table.key, table.value\
        \n TableScan: table\
        \n Projection: table.id AS id, table.key, table.value\
        \n TableScan: table";
        assert_optimized_plan_equal(&plan, expected)
    }

    /// Inputs whose `id` / `value` column types differ get cast projections,
    /// and those projections survive the flattening.
    #[test]
    fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
        // All three scans are registered under the name "table_1"; only the
        // `id` and `value` column types vary.
        // NOTE(review): the shared name looks deliberate, since the expected
        // plan refers to `table_1` throughout -- confirm before renaming.
        let scan_of = |id_type: DataType, value_type: DataType| {
            let fields = vec![
                Field::new("id", id_type, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", value_type, false),
            ];
            table_scan(Some("table_1"), &Schema::new(fields), None)
        };

        let table_1 = scan_of(DataType::Int64, DataType::Float64)?;
        let table_2 = scan_of(DataType::Int32, DataType::Float32)?;
        let table_3 = scan_of(DataType::Int16, DataType::Float32)?;

        let second = table_2.build()?;
        let third = table_3.build()?;
        let plan = table_1.union(second)?.union(third)?.build()?;

        let expected = "Union\
        \n TableScan: table_1\
        \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
        \n TableScan: table_1\
        \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
        \n TableScan: table_1";
        assert_optimized_plan_equal(&plan, expected)
    }
}
Loading

0 comments on commit 704e034

Please sign in to comment.