Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: move coercion of union from builder to TypeCoercion #11961

Merged
merged 16 commits into from
Aug 14, 2024
111 changes: 6 additions & 105 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
use std::any::Any;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::iter::zip;
use std::sync::Arc;

use crate::dml::CopyTo;
Expand All @@ -36,7 +35,7 @@ use crate::logical_plan::{
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
use crate::type_coercion::binary::{comparison_coercion, values_coercion};
use crate::type_coercion::binary::values_coercion;
use crate::utils::{
can_hash, columnize_expr, compare_sort_expr, expand_qualified_wildcard,
expand_wildcard, expr_to_columns, find_valid_equijoin_key_pair,
Expand Down Expand Up @@ -1339,95 +1338,14 @@ pub(crate) fn validate_unique_names<'a>(
})
}

/// Build a projection over `input` whose expressions carry, position by
/// position, the field names of `schema`.
///
/// An expression that is already an alias or a column with the expected name
/// is kept untouched; an alias with a different name is stripped and
/// re-aliased; everything else is aliased to the expected name.
pub fn project_with_column_index(
    expr: Vec<Expr>,
    input: Arc<LogicalPlan>,
    schema: DFSchemaRef,
) -> Result<LogicalPlan> {
    let renamed = expr
        .into_iter()
        .enumerate()
        .map(|(position, e)| {
            let target = schema.field(position).name();
            // Only aliases and plain columns can already carry the right name.
            let already_named = match &e {
                Expr::Alias(Alias { name, .. }) | Expr::Column(Column { name, .. }) => {
                    name == target
                }
                _ => false,
            };
            if already_named {
                e
            } else if matches!(e, Expr::Alias(_)) {
                // Replace the mismatching alias rather than stacking a second one.
                e.unalias().alias(target)
            } else {
                e.alias(target)
            }
        })
        .collect::<Vec<_>>();

    Projection::try_new_with_schema(renamed, input, schema).map(LogicalPlan::Projection)
}

/// Union two logical plans.
pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalPlan> {
let left_col_num = left_plan.schema().fields().len();

// check union plan length same.
let right_col_num = right_plan.schema().fields().len();
if right_col_num != left_col_num {
return plan_err!(
"Union queries must have the same number of columns, (left is {left_col_num}, right is {right_col_num})");
}

// create union schema
let union_qualified_fields =
zip(left_plan.schema().iter(), right_plan.schema().iter())
.map(
|((left_qualifier, left_field), (_right_qualifier, right_field))| {
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type = comparison_coercion(
left_field.data_type(),
right_field.data_type(),
)
.ok_or_else(|| {
plan_datafusion_err!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
left_field.data_type()
)
})?;
Ok((
left_qualifier.cloned(),
Arc::new(Field::new(left_field.name(), data_type, nullable)),
))
},
)
.collect::<Result<Vec<_>>>()?;
let union_schema =
DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())?;

let inputs = vec![left_plan, right_plan]
.into_iter()
.map(|p| {
let plan = coerce_plan_expr_for_schema(&p, &union_schema)?;
match plan {
LogicalPlan::Projection(Projection { expr, input, .. }) => {
Ok(Arc::new(project_with_column_index(
expr,
input,
Arc::new(union_schema.clone()),
)?))
}
other_plan => Ok(Arc::new(other_plan)),
}
})
.collect::<Result<Vec<_>>>()?;

if inputs.is_empty() {
return plan_err!("Empty UNION");
}

// Temporarily use the schema from the left input and later rely on the analyzer to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// coerce the two schemas into a common one.
let schema = Arc::clone(left_plan.schema());
Ok(LogicalPlan::Union(Union {
inputs,
schema: Arc::new(union_schema),
inputs: vec![Arc::new(left_plan), Arc::new(right_plan)],
schema,
}))
}

Expand Down Expand Up @@ -1881,23 +1799,6 @@ mod tests {
Ok(())
}

#[test]
fn plan_builder_union_different_num_columns_error() -> Result<()> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to slt.

let plan1 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3]))?;
let plan2 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?;

let expected = "Error during planning: Union queries must have the same number of columns, (left is 1, right is 2)";
let err_msg1 = plan1.clone().union(plan2.clone().build()?).unwrap_err();
let err_msg2 = plan1.union_distinct(plan2.build()?).unwrap_err();

assert_eq!(err_msg1.strip_backtrace(), expected);
assert_eq!(err_msg2.strip_backtrace(), expected);

Ok(())
}

#[test]
fn plan_builder_simple_distinct() -> Result<()> {
let plan =
Expand Down
147 changes: 133 additions & 14 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,24 @@

//! Optimizer rule for type validation and coercion

use std::collections::HashMap;
use std::sync::Arc;

use arrow::datatypes::{DataType, IntervalUnit};
use itertools::izip;

use arrow::datatypes::{DataType, Field, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
DataFusionError, Result, ScalarValue,
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{
self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction,
WindowFunction,
self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like,
ScalarFunction, WindowFunction,
};
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::Subquery;
Expand All @@ -47,8 +51,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator, ScalarUDF,
WindowFrame, WindowFrameBound, WindowFrameUnits,
AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator,
Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
};

use crate::analyzer::AnalyzerRule;
Expand Down Expand Up @@ -120,8 +124,8 @@ fn analyze_internal(
expr.rewrite(&mut expr_rewrite)?
.map_data(|expr| original_name.restore(expr))
})?
// coerce join expressions specially
.map_data(|plan| expr_rewrite.coerce_joins(plan))?
// some plans need extra coercion after their expressions are coerced
.map_data(|plan| expr_rewrite.coerce_plan(plan))?
// recompute the schema after the expressions have been rewritten as the types may have changed
.map_data(|plan| plan.recompute_schema())
}
Expand All @@ -135,6 +139,14 @@ impl<'a> TypeCoercionRewriter<'a> {
Self { schema }
}

/// Run the plan-level coercion that expression rewriting alone cannot do.
///
/// Joins are handed to `coerce_join`, unions to `coerce_union`; every other
/// plan is returned unchanged.
fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
    if let LogicalPlan::Join(join) = plan {
        return self.coerce_join(join);
    }
    if let LogicalPlan::Union(union) = plan {
        return coerce_union(union);
    }
    Ok(plan)
}

/// Coerce join equality expressions and join filter
///
/// Joins must be treated specially as their equality expressions are stored
Expand All @@ -143,11 +155,7 @@ impl<'a> TypeCoercionRewriter<'a> {
///
/// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
/// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
fn coerce_joins(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
let LogicalPlan::Join(mut join) = plan else {
return Ok(plan);
};

fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
join.on = join
.on
.into_iter()
Expand Down Expand Up @@ -774,6 +782,117 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
Ok(Case::new(case_expr, when_then, else_expr))
}

/// Get a common schema that is compatible with all inputs of UNION.
///
/// The first input provides the field names and qualifiers. Each field's data
/// type is the result of folding `comparison_coercion` over the corresponding
/// field of every input, and a field is nullable if it is nullable in any input.
/// The returned schema is created with empty metadata.
///
/// # Errors
///
/// Returns a planning error if `inputs` is empty, if an input has a different
/// number of fields than the first one, or if two corresponding field types
/// have no common coerced type.
fn coerce_union_schema(inputs: Vec<Arc<LogicalPlan>>) -> Result<DFSchema> {
    // Report an empty UNION as a planning error instead of panicking on `inputs[0]`.
    let Some(first_input) = inputs.first() else {
        return plan_err!("Empty UNION");
    };
    let base_schema = first_input.schema();
    let mut union_datatypes = base_schema
        .fields()
        .iter()
        .map(|f| f.data_type().clone())
        .collect::<Vec<_>>();
    let mut union_nullabilities = base_schema
        .fields()
        .iter()
        .map(|f| f.is_nullable())
        .collect::<Vec<_>>();

    for (i, plan) in inputs.iter().enumerate().skip(1) {
        let plan_schema = plan.schema();
        if plan_schema.fields().len() != base_schema.fields().len() {
            return plan_err!(
                "Union schemas have different number of fields: \
                query 1 has {} fields whereas query {} has {} fields",
                base_schema.fields().len(),
                i + 1,
                plan_schema.fields().len()
            );
        }
        // coerce data type and nullability for each field
        for (union_datatype, union_nullable, plan_field) in izip!(
            union_datatypes.iter_mut(),
            union_nullabilities.iter_mut(),
            plan_schema.fields()
        ) {
            let coerced_type =
                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
                    || {
                        plan_datafusion_err!(
                            "UNION Column '{}' (type: {}) is not compatible with other type: {}",
                            plan_field.name(),
                            plan_field.data_type(),
                            union_datatype
                        )
                    },
                )?;
            *union_datatype = coerced_type;
            *union_nullable = *union_nullable || plan_field.is_nullable();
        }
    }

    // Names and qualifiers come from the first input; types and nullability
    // are the values accumulated across all inputs.
    let union_qualified_fields =
        izip!(base_schema.iter(), union_datatypes, union_nullabilities)
            .map(|((qualifier, field), datatype, nullable)| {
                let field = Arc::new(Field::new(field.name().clone(), datatype, nullable));
                (qualifier.cloned(), field)
            })
            .collect::<Vec<_>>();
    DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())
}

/// Coerce the union's inputs to a common schema.
///
/// Computes the common schema with `coerce_union_schema`, rewrites every input
/// through `coerce_plan_expr_for_schema`, and re-projects inputs that come back
/// as a `Projection` so their output names follow the union schema. The
/// returned `Union` carries the coerced inputs and the common schema.
///
/// # Errors
///
/// Propagates any error from computing the common schema or from coercing an
/// individual input.
fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
    // Wrap the schema in an `Arc` once so each projected input and the final
    // union share it instead of deep-cloning the `DFSchema` per input.
    let union_schema = Arc::new(coerce_union_schema(union_plan.inputs.clone())?);
    let new_inputs = union_plan
        .inputs
        .iter()
        .map(|p| {
            let plan = coerce_plan_expr_for_schema(p, &union_schema)?;
            match plan {
                LogicalPlan::Projection(Projection { expr, input, .. }) => {
                    Ok(Arc::new(project_with_column_index(
                        expr,
                        input,
                        Arc::clone(&union_schema),
                    )?))
                }
                other_plan => Ok(Arc::new(other_plan)),
            }
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(LogicalPlan::Union(Union {
        inputs: new_inputs,
        schema: union_schema,
    }))
}

/// See `<https://github.com/apache/datafusion/pull/2108>`
///
/// Wraps `input` in a projection whose expressions are renamed, by position,
/// to the field names of `schema`. Aliases and columns that already have the
/// expected name are left as they are.
fn project_with_column_index(
    expr: Vec<Expr>,
    input: Arc<LogicalPlan>,
    schema: DFSchemaRef,
) -> Result<LogicalPlan> {
    let mut alias_expr = Vec::with_capacity(expr.len());
    for (idx, e) in expr.into_iter().enumerate() {
        let wanted = schema.field(idx).name();
        // Only an alias or a bare column can already expose the wanted name.
        let name_matches = match &e {
            Expr::Alias(Alias { name, .. }) | Expr::Column(Column { name, .. }) => {
                name == wanted
            }
            _ => false,
        };
        let projected = if name_matches {
            e
        } else if matches!(e, Expr::Alias(_)) {
            // Swap the mismatching alias for the wanted one.
            e.unalias().alias(wanted)
        } else {
            e.alias(wanted)
        };
        alias_expr.push(projected);
    }

    Projection::try_new_with_schema(alias_expr, input, schema)
        .map(LogicalPlan::Projection)
}

#[cfg(test)]
mod test {
use std::any::Any;
Expand Down
12 changes: 11 additions & 1 deletion datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
#[cfg(test)]
mod tests {
use super::*;
use crate::analyzer::type_coercion::TypeCoercion;
use crate::analyzer::Analyzer;
use crate::test::*;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{col, logical_plan::table_scan};

fn schema() -> Schema {
Expand All @@ -127,7 +130,14 @@ mod tests {
}

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected)
let options = ConfigOptions::default();
let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add TypeCoercion to avoid breaking the tests.

.execute_and_check(plan, &options, |_, _| {})?;
assert_optimized_plan_eq(
Arc::new(EliminateNestedUnion::new()),
analyzed_plan,
expected,
)
}

#[test]
Expand Down
Loading