Skip to content

Commit

Permalink
Revert "fix: move coercion of union from builder to TypeCoercion (a…
Browse files Browse the repository at this point in the history
…pache#11961)"

This reverts commit afa23ab.
  • Loading branch information
wiedld committed Aug 19, 2024
1 parent c6c7a73 commit 4885e32
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 355 deletions.
112 changes: 106 additions & 6 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use std::any::Any;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::iter::zip;
use std::sync::Arc;

use crate::dml::CopyTo;
Expand All @@ -35,7 +36,7 @@ use crate::logical_plan::{
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
use crate::type_coercion::binary::values_coercion;
use crate::type_coercion::binary::{comparison_coercion, values_coercion};
use crate::utils::{
can_hash, columnize_expr, compare_sort_expr, expr_to_columns,
find_valid_equijoin_key_pair, group_window_expr_by_sort_keys,
Expand Down Expand Up @@ -1337,14 +1338,96 @@ pub fn validate_unique_names<'a>(
})
}

pub fn project_with_column_index(
expr: Vec<Expr>,
input: Arc<LogicalPlan>,
schema: DFSchemaRef,
) -> Result<LogicalPlan> {
let alias_expr = expr
.into_iter()
.enumerate()
.map(|(i, e)| match e {
Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
e.unalias().alias(schema.field(i).name())
}
Expr::Column(Column {
relation: _,
ref name,
}) if name != schema.field(i).name() => e.alias(schema.field(i).name()),
Expr::Alias { .. } | Expr::Column { .. } => e,
Expr::Wildcard { .. } => e,
_ => e.alias(schema.field(i).name()),
})
.collect::<Vec<_>>();

Projection::try_new_with_schema(alias_expr, input, schema)
.map(LogicalPlan::Projection)
}

/// Union two logical plans.
pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalPlan> {
// Temporarily use the schema from the left input and later rely on the analyzer to
// coerce the two schemas into a common one.
let schema = Arc::clone(left_plan.schema());
let left_col_num = left_plan.schema().fields().len();

// check union plan length same.
let right_col_num = right_plan.schema().fields().len();
if right_col_num != left_col_num {
return plan_err!(
"Union queries must have the same number of columns, (left is {left_col_num}, right is {right_col_num})");
}

// create union schema
let union_qualified_fields =
zip(left_plan.schema().iter(), right_plan.schema().iter())
.map(
|((left_qualifier, left_field), (_right_qualifier, right_field))| {
let nullable = left_field.is_nullable() || right_field.is_nullable();
let data_type = comparison_coercion(
left_field.data_type(),
right_field.data_type(),
)
.ok_or_else(|| {
plan_datafusion_err!(
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
right_field.name(),
right_field.data_type(),
left_field.name(),
left_field.data_type()
)
})?;
Ok((
left_qualifier.cloned(),
Arc::new(Field::new(left_field.name(), data_type, nullable)),
))
},
)
.collect::<Result<Vec<_>>>()?;
let union_schema =
DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())?;

let inputs = vec![left_plan, right_plan]
.into_iter()
.map(|p| {
let plan = coerce_plan_expr_for_schema(&p, &union_schema)?;
match plan {
LogicalPlan::Projection(Projection { expr, input, .. }) => {
Ok(Arc::new(project_with_column_index(
expr,
input,
Arc::new(union_schema.clone()),
)?))
}
other_plan => Ok(Arc::new(other_plan)),
}
})
.collect::<Result<Vec<_>>>()?;

if inputs.is_empty() {
return plan_err!("Empty UNION");
}

Ok(LogicalPlan::Union(Union {
inputs: vec![Arc::new(left_plan), Arc::new(right_plan)],
schema,
inputs,
schema: Arc::new(union_schema),
}))
}

Expand Down Expand Up @@ -1767,6 +1850,23 @@ mod tests {
Ok(())
}

#[test]
fn plan_builder_union_different_num_columns_error() -> Result<()> {
let plan1 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3]))?;
let plan2 =
table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?;

let expected = "Error during planning: Union queries must have the same number of columns, (left is 1, right is 2)";
let err_msg1 = plan1.clone().union(plan2.clone().build()?).unwrap_err();
let err_msg2 = plan1.union_distinct(plan2.build()?).unwrap_err();

assert_eq!(err_msg1.strip_backtrace(), expected);
assert_eq!(err_msg2.strip_backtrace(), expected);

Ok(())
}

#[test]
fn plan_builder_simple_distinct() -> Result<()> {
let plan =
Expand Down
147 changes: 31 additions & 116 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,22 @@

//! Optimizer rule for type validation and coercion

use std::collections::HashMap;
use std::sync::Arc;

use itertools::izip;

use arrow::datatypes::{DataType, Field, IntervalUnit};
use arrow::datatypes::{DataType, IntervalUnit};

use crate::analyzer::AnalyzerRule;
use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
DataFusionError, Result, ScalarValue,
};
use datafusion_expr::builder::project_with_column_index;
use datafusion_expr::expr::{
self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like,
ScalarFunction, WindowFunction,
self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction,
WindowFunction,
};
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use datafusion_expr::expr_schema::cast_subquery;
Expand All @@ -53,7 +51,7 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator,
AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator,
Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
};

Expand Down Expand Up @@ -123,8 +121,9 @@ fn analyze_internal(
expr.rewrite(&mut expr_rewrite)?
.map_data(|expr| original_name.restore(expr))
})?
// some plans need extra coercion after their expressions are coerced
.map_data(|plan| expr_rewrite.coerce_plan(plan))?
// coerce join expressions specially
.map_data(|plan| expr_rewrite.coerce_joins(plan))?
.map_data(|plan| expr_rewrite.coerce_union(plan))?
// recompute the schema after the expressions have been rewritten as the types may have changed
.map_data(|plan| plan.recompute_schema())
}
Expand All @@ -138,14 +137,6 @@ impl<'a> TypeCoercionRewriter<'a> {
Self { schema }
}

fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
match plan {
LogicalPlan::Join(join) => self.coerce_join(join),
LogicalPlan::Union(union) => Self::coerce_union(union),
_ => Ok(plan),
}
}

/// Coerce join equality expressions and join filter
///
/// Joins must be treated specially as their equality expressions are stored
Expand All @@ -154,7 +145,11 @@ impl<'a> TypeCoercionRewriter<'a> {
///
/// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
/// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
fn coerce_joins(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
let LogicalPlan::Join(mut join) = plan else {
return Ok(plan);
};

join.on = join
.on
.into_iter()
Expand All @@ -175,30 +170,36 @@ impl<'a> TypeCoercionRewriter<'a> {
Ok(LogicalPlan::Join(join))
}

/// Coerce the union’s inputs to a common schema compatible with all inputs.
/// This occurs after wildcard expansion and the coercion of the input expressions.
fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
let union_schema = Arc::new(coerce_union_schema(&union_plan.inputs)?);
let new_inputs = union_plan
/// Corece the union inputs after expanding the wildcard expressions
///
/// Union inputs must have the same schema, so we coerce the expressions to match the schema
/// after expanding the wildcard expressions
fn coerce_union(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
let LogicalPlan::Union(union) = plan else {
return Ok(plan);
};

let inputs = union
.inputs
.iter()
.into_iter()
.map(|p| {
let plan = coerce_plan_expr_for_schema(p, &union_schema)?;
let plan = coerce_plan_expr_for_schema(&p, &union.schema)?;
match plan {
LogicalPlan::Projection(Projection { expr, input, .. }) => {
Ok(Arc::new(project_with_column_index(
expr,
input,
Arc::clone(&union_schema),
Arc::clone(&union.schema),
)?))
}
other_plan => Ok(Arc::new(other_plan)),
}
})
.collect::<Result<Vec<_>>>()?;

Ok(LogicalPlan::Union(Union {
inputs: new_inputs,
schema: union_schema,
inputs,
schema: Arc::clone(&union.schema),
}))
}

Expand Down Expand Up @@ -808,92 +809,6 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
Ok(Case::new(case_expr, when_then, else_expr))
}

/// Get a common schema that is compatible with all inputs of UNION.
fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
let base_schema = inputs[0].schema();
let mut union_datatypes = base_schema
.fields()
.iter()
.map(|f| f.data_type().clone())
.collect::<Vec<_>>();
let mut union_nullabilities = base_schema
.fields()
.iter()
.map(|f| f.is_nullable())
.collect::<Vec<_>>();

for (i, plan) in inputs.iter().enumerate().skip(1) {
let plan_schema = plan.schema();
if plan_schema.fields().len() != base_schema.fields().len() {
return plan_err!(
"Union schemas have different number of fields: \
query 1 has {} fields whereas query {} has {} fields",
base_schema.fields().len(),
i + 1,
plan_schema.fields().len()
);
}
// coerce data type and nullablity for each field
for (union_datatype, union_nullable, plan_field) in izip!(
union_datatypes.iter_mut(),
union_nullabilities.iter_mut(),
plan_schema.fields()
) {
let coerced_type =
comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
|| {
plan_datafusion_err!(
"Incompatible inputs for Union: Previous inputs were \
of type {}, but got incompatible type {} on column '{}'",
union_datatype,
plan_field.data_type(),
plan_field.name()
)
},
)?;
*union_datatype = coerced_type;
*union_nullable = *union_nullable || plan_field.is_nullable();
}
}
let union_qualified_fields = izip!(
base_schema.iter(),
union_datatypes.into_iter(),
union_nullabilities
)
.map(|((qualifier, field), datatype, nullable)| {
let field = Arc::new(Field::new(field.name().clone(), datatype, nullable));
(qualifier.cloned(), field)
})
.collect::<Vec<_>>();
DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())
}

/// See `<https://github.com/apache/datafusion/pull/2108>`
fn project_with_column_index(
expr: Vec<Expr>,
input: Arc<LogicalPlan>,
schema: DFSchemaRef,
) -> Result<LogicalPlan> {
let alias_expr = expr
.into_iter()
.enumerate()
.map(|(i, e)| match e {
Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
e.unalias().alias(schema.field(i).name())
}
Expr::Column(Column {
relation: _,
ref name,
}) if name != schema.field(i).name() => e.alias(schema.field(i).name()),
Expr::Alias { .. } | Expr::Column { .. } => e,
_ => e.alias(schema.field(i).name()),
})
.collect::<Vec<_>>();

Projection::try_new_with_schema(alias_expr, input, schema)
.map(LogicalPlan::Projection)
}

#[cfg(test)]
mod test {
use std::any::Any;
Expand Down
12 changes: 1 addition & 11 deletions datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,8 @@ fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
#[cfg(test)]
mod tests {
use super::*;
use crate::analyzer::type_coercion::TypeCoercion;
use crate::analyzer::Analyzer;
use crate::test::*;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{col, logical_plan::table_scan};

fn schema() -> Schema {
Expand All @@ -130,14 +127,7 @@ mod tests {
}

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
let options = ConfigOptions::default();
let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
.execute_and_check(plan, &options, |_, _| {})?;
assert_optimized_plan_eq(
Arc::new(EliminateNestedUnion::new()),
analyzed_plan,
expected,
)
assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected)
}

#[test]
Expand Down
Loading

0 comments on commit 4885e32

Please sign in to comment.