Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: move coercion of union from builder to TypeCoercion #11961

Merged
merged 16 commits into from
Aug 14, 2024
112 changes: 6 additions & 106 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
use std::any::Any;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::iter::zip;
use std::sync::Arc;

use crate::dml::CopyTo;
Expand All @@ -36,7 +35,7 @@ use crate::logical_plan::{
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
use crate::type_coercion::binary::{comparison_coercion, values_coercion};
use crate::type_coercion::binary::values_coercion;
use crate::utils::{
can_hash, columnize_expr, compare_sort_expr, expr_to_columns,
find_valid_equijoin_key_pair, group_window_expr_by_sort_keys,
Expand Down Expand Up @@ -1338,96 +1337,14 @@ pub fn validate_unique_names<'a>(
})
}

pub fn project_with_column_index(
expr: Vec<Expr>,
input: Arc<LogicalPlan>,
schema: DFSchemaRef,
) -> Result<LogicalPlan> {
let alias_expr = expr
.into_iter()
.enumerate()
.map(|(i, e)| match e {
Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
e.unalias().alias(schema.field(i).name())
}
Expr::Column(Column {
relation: _,
ref name,
}) if name != schema.field(i).name() => e.alias(schema.field(i).name()),
Expr::Alias { .. } | Expr::Column { .. } => e,
Expr::Wildcard { .. } => e,
_ => e.alias(schema.field(i).name()),
})
.collect::<Vec<_>>();

Projection::try_new_with_schema(alias_expr, input, schema)
.map(LogicalPlan::Projection)
}

/// Union two logical plans.
pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalPlan> {
let left_col_num = left_plan.schema().fields().len();

// check union plan length same.
let right_col_num = right_plan.schema().fields().len();
if right_col_num != left_col_num {
return plan_err!(
"Union queries must have the same number of columns, (left is {left_col_num}, right is {right_col_num})");
}

// create union schema
let union_qualified_fields =
zip(left_plan.schema().iter(), right_plan.schema().iter())
.map(
|((left_qualifier, left_field), (_right_qualifier, right_field))| {
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type = comparison_coercion(
left_field.data_type(),
right_field.data_type(),
)
.ok_or_else(|| {
plan_datafusion_err!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
left_field.data_type()
)
})?;
Ok((
left_qualifier.cloned(),
Arc::new(Field::new(left_field.name(), data_type, nullable)),
))
},
)
.collect::<Result<Vec<_>>>()?;
let union_schema =
DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())?;

let inputs = vec![left_plan, right_plan]
.into_iter()
.map(|p| {
let plan = coerce_plan_expr_for_schema(&p, &union_schema)?;
match plan {
LogicalPlan::Projection(Projection { expr, input, .. }) => {
Ok(Arc::new(project_with_column_index(
expr,
input,
Arc::new(union_schema.clone()),
)?))
}
other_plan => Ok(Arc::new(other_plan)),
}
})
.collect::<Result<Vec<_>>>()?;

if inputs.is_empty() {
return plan_err!("Empty UNION");
}

// Temporarily use the schema from the left input and later rely on the analyzer to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// coerce the two schemas into a common one.
let schema = Arc::clone(left_plan.schema());
Ok(LogicalPlan::Union(Union {
inputs,
schema: Arc::new(union_schema),
inputs: vec![Arc::new(left_plan), Arc::new(right_plan)],
schema,
}))
}

Expand Down Expand Up @@ -1850,23 +1767,6 @@ mod tests {
Ok(())
}

#[test]
fn plan_builder_union_different_num_columns_error() -> Result<()> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to slt.

let plan1 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3]))?;
let plan2 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?;

let expected = "Error during planning: Union queries must have the same number of columns, (left is 1, right is 2)";
let err_msg1 = plan1.clone().union(plan2.clone().build()?).unwrap_err();
let err_msg2 = plan1.union_distinct(plan2.build()?).unwrap_err();

assert_eq!(err_msg1.strip_backtrace(), expected);
assert_eq!(err_msg2.strip_backtrace(), expected);

Ok(())
}

#[test]
fn plan_builder_simple_distinct() -> Result<()> {
let plan =
Expand Down
147 changes: 116 additions & 31 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,24 @@

//! Optimizer rule for type validation and coercion

use std::collections::HashMap;
use std::sync::Arc;

use arrow::datatypes::{DataType, IntervalUnit};
use itertools::izip;

use arrow::datatypes::{DataType, Field, IntervalUnit};

use crate::analyzer::AnalyzerRule;
use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
DataFusionError, Result, ScalarValue,
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::builder::project_with_column_index;
use datafusion_expr::expr::{
self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction,
WindowFunction,
self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like,
ScalarFunction, WindowFunction,
};
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use datafusion_expr::expr_schema::cast_subquery;
Expand All @@ -51,7 +53,7 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator,
AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator,
Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
};

Expand Down Expand Up @@ -121,9 +123,8 @@ fn analyze_internal(
expr.rewrite(&mut expr_rewrite)?
.map_data(|expr| original_name.restore(expr))
})?
// coerce join expressions specially
.map_data(|plan| expr_rewrite.coerce_joins(plan))?
.map_data(|plan| expr_rewrite.coerce_union(plan))?
// some plans need extra coercion after their expressions are coerced
.map_data(|plan| expr_rewrite.coerce_plan(plan))?
// recompute the schema after the expressions have been rewritten as the types may have changed
.map_data(|plan| plan.recompute_schema())
}
Expand All @@ -137,6 +138,14 @@ impl<'a> TypeCoercionRewriter<'a> {
Self { schema }
}

fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
match plan {
LogicalPlan::Join(join) => self.coerce_join(join),
LogicalPlan::Union(union) => Self::coerce_union(union),
_ => Ok(plan),
}
}

/// Coerce join equality expressions and join filter
///
/// Joins must be treated specially as their equality expressions are stored
Expand All @@ -145,11 +154,7 @@ impl<'a> TypeCoercionRewriter<'a> {
///
/// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
/// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
fn coerce_joins(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
let LogicalPlan::Join(mut join) = plan else {
return Ok(plan);
};

fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
join.on = join
.on
.into_iter()
Expand All @@ -170,36 +175,30 @@ impl<'a> TypeCoercionRewriter<'a> {
Ok(LogicalPlan::Join(join))
}

/// Corece the union inputs after expanding the wildcard expressions
///
/// Union inputs must have the same schema, so we coerce the expressions to match the schema
/// after expanding the wildcard expressions
fn coerce_union(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
let LogicalPlan::Union(union) = plan else {
return Ok(plan);
};

let inputs = union
/// Coerce the union’s inputs to a common schema compatible with all inputs.
/// This occurs after wildcard expansion and the coercion of the input expressions.
fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
let union_schema = Arc::new(coerce_union_schema(&union_plan.inputs)?);
let new_inputs = union_plan
.inputs
.into_iter()
.iter()
.map(|p| {
let plan = coerce_plan_expr_for_schema(&p, &union.schema)?;
let plan = coerce_plan_expr_for_schema(p, &union_schema)?;
match plan {
LogicalPlan::Projection(Projection { expr, input, .. }) => {
Ok(Arc::new(project_with_column_index(
expr,
input,
Arc::clone(&union.schema),
Arc::clone(&union_schema),
)?))
}
other_plan => Ok(Arc::new(other_plan)),
}
})
.collect::<Result<Vec<_>>>()?;

Ok(LogicalPlan::Union(Union {
inputs,
schema: Arc::clone(&union.schema),
inputs: new_inputs,
schema: union_schema,
}))
}

Expand Down Expand Up @@ -809,6 +808,92 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
Ok(Case::new(case_expr, when_then, else_expr))
}

/// Get a common schema that is compatible with all inputs of UNION.
fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
let base_schema = inputs[0].schema();
let mut union_datatypes = base_schema
.fields()
.iter()
.map(|f| f.data_type().clone())
.collect::<Vec<_>>();
let mut union_nullabilities = base_schema
.fields()
.iter()
.map(|f| f.is_nullable())
.collect::<Vec<_>>();

for (i, plan) in inputs.iter().enumerate().skip(1) {
let plan_schema = plan.schema();
if plan_schema.fields().len() != base_schema.fields().len() {
return plan_err!(
"Union schemas have different number of fields: \
query 1 has {} fields whereas query {} has {} fields",
base_schema.fields().len(),
i + 1,
plan_schema.fields().len()
);
}
// coerce data type and nullablity for each field
for (union_datatype, union_nullable, plan_field) in izip!(
union_datatypes.iter_mut(),
union_nullabilities.iter_mut(),
plan_schema.fields()
) {
let coerced_type =
comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
|| {
plan_datafusion_err!(
"Incompatible inputs for Union: Previous inputs were \
of type {}, but got incompatible type {} on column '{}'",
union_datatype,
plan_field.data_type(),
plan_field.name()
)
},
)?;
*union_datatype = coerced_type;
*union_nullable = *union_nullable || plan_field.is_nullable();
}
}
let union_qualified_fields = izip!(
base_schema.iter(),
union_datatypes.into_iter(),
union_nullabilities
)
.map(|((qualifier, field), datatype, nullable)| {
let field = Arc::new(Field::new(field.name().clone(), datatype, nullable));
(qualifier.cloned(), field)
})
.collect::<Vec<_>>();
DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())
}

/// See `<https://github.com/apache/datafusion/pull/2108>`
fn project_with_column_index(
expr: Vec<Expr>,
input: Arc<LogicalPlan>,
schema: DFSchemaRef,
) -> Result<LogicalPlan> {
let alias_expr = expr
.into_iter()
.enumerate()
.map(|(i, e)| match e {
Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
e.unalias().alias(schema.field(i).name())
}
Expr::Column(Column {
relation: _,
ref name,
}) if name != schema.field(i).name() => e.alias(schema.field(i).name()),
Expr::Alias { .. } | Expr::Column { .. } => e,
_ => e.alias(schema.field(i).name()),
})
.collect::<Vec<_>>();

Projection::try_new_with_schema(alias_expr, input, schema)
.map(LogicalPlan::Projection)
}

#[cfg(test)]
mod test {
use std::any::Any;
Expand Down
12 changes: 11 additions & 1 deletion datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
#[cfg(test)]
mod tests {
use super::*;
use crate::analyzer::type_coercion::TypeCoercion;
use crate::analyzer::Analyzer;
use crate::test::*;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{col, logical_plan::table_scan};

fn schema() -> Schema {
Expand All @@ -127,7 +130,14 @@ mod tests {
}

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected)
let options = ConfigOptions::default();
let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add TypeCoercion to avoid breaking the tests.

.execute_and_check(plan, &options, |_, _| {})?;
assert_optimized_plan_eq(
Arc::new(EliminateNestedUnion::new()),
analyzed_plan,
expected,
)
}

#[test]
Expand Down
Loading