Skip to content

Commit

Permalink
Add eliminate_distinct_nested_union spagetti implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
maruschin authored and Evgeny Maruschenko committed Oct 11, 2023
1 parent 92ba6c3 commit 7b01b79
Showing 1 changed file with 193 additions and 15 deletions.
208 changes: 193 additions & 15 deletions datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
// under the License.

//! Optimizer rule to replace nested unions to single union.
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::logical_plan::{LogicalPlan, Union};

use crate::optimizer::ApplyOrder;
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use datafusion_expr::{Distinct, LogicalPlan, Union};
use std::sync::Arc;

#[derive(Default)]
Expand All @@ -41,29 +40,35 @@ impl OptimizerRule for EliminateNestedUnion {
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
// TODO: Add optimization for nested distinct unions.
match plan {
LogicalPlan::Union(Union { inputs, schema }) => {
let inputs = inputs
.iter()
.flat_map(|plan| match plan.as_ref() {
LogicalPlan::Union(Union { inputs, schema }) => inputs
.iter()
.map(|plan| {
Arc::new(
coerce_plan_expr_for_schema(plan, schema).unwrap(),
)
})
.collect::<Vec<_>>(),
_ => vec![plan.clone()],
})
.flat_map(extract_plans_from_union)
.collect::<Vec<_>>();

Ok(Some(LogicalPlan::Union(Union {
inputs,
schema: schema.clone(),
})))
}
LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() {
LogicalPlan::Union(Union { inputs, schema }) => {
let inputs = inputs
.iter()
.map(extract_plan_from_distinct)
.flat_map(extract_plans_from_union)
.collect::<Vec<_>>();

Ok(Some(LogicalPlan::Distinct(Distinct {
input: Arc::new(LogicalPlan::Union(Union {
inputs,
schema: schema.clone(),
})),
})))
}
_ => Ok(None),
},
_ => Ok(None),
}
}
Expand All @@ -77,6 +82,23 @@ impl OptimizerRule for EliminateNestedUnion {
}
}

fn extract_plans_from_union(plan: &Arc<LogicalPlan>) -> Vec<Arc<LogicalPlan>> {
match plan.as_ref() {
LogicalPlan::Union(Union { inputs, schema }) => inputs
.iter()
.map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap()))
.collect::<Vec<_>>(),
_ => vec![plan.clone()],
}
}

fn extract_plan_from_distinct(plan: &Arc<LogicalPlan>) -> &Arc<LogicalPlan> {
match plan.as_ref() {
LogicalPlan::Distinct(Distinct { input: plan }) => plan,
_ => plan,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -112,6 +134,22 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_distinct_nothing() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(plan_builder.clone().build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_union() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;
Expand All @@ -132,6 +170,69 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_union_with_distinct_union() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(plan_builder.clone().build()?)?
.union(plan_builder.clone().build()?)?
.union(plan_builder.clone().build()?)?
.build()?;

let expected = "Union\
\n Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union(plan_builder.clone().build()?)?
.union_distinct(plan_builder.clone().build()?)?
.union(plan_builder.clone().build()?)?
.union_distinct(plan_builder.clone().build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(plan_builder.clone().distinct()?.build()?)?
.union(plan_builder.clone().distinct()?.build()?)?
.union_distinct(plan_builder.clone().build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

// We don't need to use project_with_column_index in logical optimizer,
// after LogicalPlanBuilder::union, we already have all equal expression aliases
#[test]
Expand Down Expand Up @@ -163,6 +264,36 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(
plan_builder
.clone()
.project(vec![col("id").alias("table_id"), col("key"), col("value")])?
.build()?,
)?
.union_distinct(
plan_builder
.clone()
.project(vec![col("id").alias("_id"), col("key"), col("value")])?
.build()?,
)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
let table_1 = table_scan(
Expand Down Expand Up @@ -208,4 +339,51 @@ mod tests {
\n TableScan: table_1";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
let table_1 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]),
None,
)?;

let table_2 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float32, false),
]),
None,
)?;

let table_3 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int16, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float32, false),
]),
None,
)?;

let plan = table_1
.union_distinct(table_2.build()?)?
.union_distinct(table_3.build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1";
assert_optimized_plan_equal(&plan, expected)
}
}

0 comments on commit 7b01b79

Please sign in to comment.