Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distinct union optimization #7788

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 193 additions & 15 deletions datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
// under the License.

//! Optimizer rule to replace nested unions to single union.
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::logical_plan::{LogicalPlan, Union};

use crate::optimizer::ApplyOrder;
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use datafusion_expr::{Distinct, LogicalPlan, Union};
use std::sync::Arc;

#[derive(Default)]
Expand All @@ -41,29 +40,35 @@ impl OptimizerRule for EliminateNestedUnion {
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
// TODO: Add optimization for nested distinct unions.
match plan {
LogicalPlan::Union(Union { inputs, schema }) => {
let inputs = inputs
.iter()
.flat_map(|plan| match plan.as_ref() {
LogicalPlan::Union(Union { inputs, schema }) => inputs
.iter()
.map(|plan| {
Arc::new(
coerce_plan_expr_for_schema(plan, schema).unwrap(),
)
})
.collect::<Vec<_>>(),
_ => vec![plan.clone()],
})
.flat_map(extract_plans_from_union)
.collect::<Vec<_>>();

Ok(Some(LogicalPlan::Union(Union {
inputs,
schema: schema.clone(),
})))
}
LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() {
LogicalPlan::Union(Union { inputs, schema }) => {
let inputs = inputs
.iter()
.map(extract_plan_from_distinct)
.flat_map(extract_plans_from_union)
.collect::<Vec<_>>();

Ok(Some(LogicalPlan::Distinct(Distinct {
input: Arc::new(LogicalPlan::Union(Union {
inputs,
schema: schema.clone(),
})),
})))
}
_ => Ok(None),
},
_ => Ok(None),
}
}
Expand All @@ -77,6 +82,23 @@ impl OptimizerRule for EliminateNestedUnion {
}
}

fn extract_plans_from_union(plan: &Arc<LogicalPlan>) -> Vec<Arc<LogicalPlan>> {
match plan.as_ref() {
LogicalPlan::Union(Union { inputs, schema }) => inputs
.iter()
.map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap()))
.collect::<Vec<_>>(),
_ => vec![plan.clone()],
}
}

fn extract_plan_from_distinct(plan: &Arc<LogicalPlan>) -> &Arc<LogicalPlan> {
match plan.as_ref() {
LogicalPlan::Distinct(Distinct { input: plan }) => plan,
_ => plan,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -112,6 +134,22 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_distinct_nothing() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(plan_builder.clone().build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_union() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;
Expand All @@ -132,6 +170,69 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_union_with_distinct_union() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(plan_builder.clone().build()?)?
.union(plan_builder.clone().build()?)?
.union(plan_builder.clone().build()?)?
.build()?;

let expected = "Union\
\n Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union(plan_builder.clone().build()?)?
.union_distinct(plan_builder.clone().build()?)?
.union(plan_builder.clone().build()?)?
.union_distinct(plan_builder.clone().build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(plan_builder.clone().distinct()?.build()?)?
.union(plan_builder.clone().distinct()?.build()?)?
.union_distinct(plan_builder.clone().build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

// We don't need to use project_with_column_index in logical optimizer,
// after LogicalPlanBuilder::union, we already have all equal expression aliases
#[test]
Expand Down Expand Up @@ -163,6 +264,36 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union_distinct(
plan_builder
.clone()
.project(vec![col("id").alias("table_id"), col("key"), col("value")])?
.build()?,
)?
.union_distinct(
plan_builder
.clone()
.project(vec![col("id").alias("_id"), col("key"), col("value")])?
.build()?,
)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
let table_1 = table_scan(
Expand Down Expand Up @@ -208,4 +339,51 @@ mod tests {
\n TableScan: table_1";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
let table_1 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]),
None,
)?;

let table_2 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float32, false),
]),
None,
)?;

let table_3 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int16, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float32, false),
]),
None,
)?;

let plan = table_1
.union_distinct(table_2.build()?)?
.union_distinct(table_3.build()?)?
.build()?;

let expected = "Distinct:\
\n Union\
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1";
assert_optimized_plan_equal(&plan, expected)
}
}
24 changes: 8 additions & 16 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -186,35 +186,27 @@ Bob_new
John
John_new

# should be un-nested
# https://github.com/apache/arrow-datafusion/issues/7786
# should be un-nested, with a single (logical) aggregate
query TT
EXPLAIN SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2)
----
logical_plan
Aggregate: groupBy=[[t1.name]], aggr=[[]]
--Union
----TableScan: t1 projection=[name]
----Aggregate: groupBy=[[t2.name]], aggr=[[]]
------Union
--------TableScan: t2 projection=[name]
--------Projection: t2.name || Utf8("_new") AS name
----------TableScan: t2 projection=[name]
----TableScan: t2 projection=[name]
----Projection: t2.name || Utf8("_new") AS name
------TableScan: t2 projection=[name]
physical_plan
AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[]
--CoalesceBatchesExec: target_batch_size=8192
----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=8
----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=12
------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[]
--------UnionExec
----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
----------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[]
------------CoalesceBatchesExec: target_batch_size=8192
--------------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=8
----------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[]
------------------UnionExec
--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
--------------------ProjectionExec: expr=[name@0 || _new as name]
----------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]
----------ProjectionExec: expr=[name@0 || _new as name]
------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0]

# nested_union_all
query T rowsort
Expand Down