From eb19a675d0ba3abe3d041d3a3c9fb07d29ee73fc Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 13 Jan 2023 05:44:19 +0800 Subject: [PATCH] Rewrite coerce_plan_expr_for_schema to fix union type coercion (#4862) * Rewrite coerce_plan_expr_for_schema * add integration tests * Update datafusion/expr/src/expr_rewriter.rs Co-authored-by: Andrew Lamb * fix comment * fix tests * fix tests Co-authored-by: Andrew Lamb --- datafusion/core/tests/sql/mod.rs | 19 +++++ datafusion/core/tests/sql/union.rs | 82 +++++++++++++++++++ .../tests/sqllogictests/test_files/union.slt | 78 ++++++++++++++++++ datafusion/expr/src/expr_rewriter.rs | 63 ++++++++++---- datafusion/expr/src/logical_plan/builder.rs | 2 +- 5 files changed, 225 insertions(+), 19 deletions(-) create mode 100644 datafusion/core/tests/sqllogictests/test_files/union.slt diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 1c50c8ad0f5b..516bd8c2431c 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -692,6 +692,25 @@ fn create_sort_merge_join_datatype_context() -> Result { Ok(ctx) } +fn create_union_context() -> Result { + let ctx = SessionContext::new(); + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::UInt8, true), + ])); + let t1_data = RecordBatch::new_empty(t1_schema); + ctx.register_batch("t1", t1_data)?; + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt8, true), + Field::new("name", DataType::UInt8, true), + ])); + let t2_data = RecordBatch::new_empty(t2_schema); + ctx.register_batch("t2", t2_data)?; + + Ok(ctx) +} + fn get_tpch_table_schema(table: &str) -> Schema { match table { "customer" => Schema::new(vec![ diff --git a/datafusion/core/tests/sql/union.rs b/datafusion/core/tests/sql/union.rs index ac0e39f4d479..abc1b00d76a2 100644 --- a/datafusion/core/tests/sql/union.rs +++ b/datafusion/core/tests/sql/union.rs @@ -140,3 +140,85 @@ async fn union_schemas() -> Result<()> { assert_batches_eq!(expected, &result); Ok(()) } + +#[tokio::test] +async fn union_with_except_input() -> Result<()> { + let ctx = create_union_context()?; + let sql = "( + SELECT name FROM t1 + EXCEPT + SELECT name FROM t2 + ) + UNION ALL + ( + SELECT name FROM t2 + EXCEPT + SELECT name FROM t1 + )"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Union [name:UInt8;N]", + " LeftAnti Join: t1.name = t2.name [name:UInt8;N]", + " Distinct: [name:UInt8;N]", + " TableScan: t1 projection=[name] [name:UInt8;N]", + " Projection: t2.name [name:UInt8;N]", + " TableScan: t2 projection=[name] [name:UInt8;N]", + " LeftAnti Join: t2.name = t1.name [name:UInt8;N]", + " Distinct: [name:UInt8;N]", + " TableScan: t2 projection=[name] [name:UInt8;N]", + " Projection: t1.name [name:UInt8;N]", + " TableScan: t1 projection=[name] [name:UInt8;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + Ok(()) +} + +#[tokio::test] +async fn union_with_type_coercion() -> Result<()> { + let ctx = create_union_context()?; + let sql = "( + SELECT id, name FROM t1 + EXCEPT + SELECT id, name FROM t2 + ) + UNION ALL + ( + SELECT id, name FROM t2 + EXCEPT + SELECT id, name FROM t1 + )"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Union [id:Int32;N, name:UInt8;N]", + " LeftAnti Join: t1.id = CAST(t2.id AS Int32), t1.name = t2.name [id:Int32;N, name:UInt8;N]", + " Distinct: [id:Int32;N, name:UInt8;N]", + " TableScan: t1 projection=[id, name] [id:Int32;N, name:UInt8;N]", + " Projection: t2.id, t2.name [id:UInt8;N, name:UInt8;N]", + " TableScan: t2 projection=[id, name] [id:UInt8;N, name:UInt8;N]", + " Projection: CAST(t2.id AS Int32) AS id, t2.name [id:Int32;N, name:UInt8;N]", + " LeftAnti Join: CAST(t2.id AS Int32) = t1.id, t2.name = t1.name [id:UInt8;N, name:UInt8;N]", + " Distinct: [id:UInt8;N, name:UInt8;N]", + " TableScan: t2 projection=[id, name] [id:UInt8;N, name:UInt8;N]", + " Projection: t1.id, t1.name [id:Int32;N, name:UInt8;N]", + " TableScan: t1 projection=[id, name] [id:Int32;N, name:UInt8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + Ok(()) +} diff --git a/datafusion/core/tests/sqllogictests/test_files/union.slt b/datafusion/core/tests/sqllogictests/test_files/union.slt new file mode 100644 index 000000000000..8e4034052af8 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/union.slt @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## UNION Tests +########## + +statement ok +CREATE TABLE t1( + id INT, + name TEXT, +) as VALUES + (1, 'Alex'), + (2, 'Bob'), + (3, 'Alice') +; + +statement ok +CREATE TABLE t2( + id TINYINT, + name TEXT, +) as VALUES + (1, 'Alex'), + (2, 'Bob'), + (3, 'John') +; + +# union with EXCEPT(JOIN) +query T +( + SELECT name FROM t1 + EXCEPT + SELECT name FROM t2 +) +UNION ALL +( + SELECT name FROM t2 + EXCEPT + SELECT name FROM t1 +) +ORDER BY name +---- +Alice +John + + + +# union with type coercion +query T +( + SELECT * FROM t1 + EXCEPT + SELECT * FROM t2 +) +UNION ALL +( + SELECT * FROM t2 + EXCEPT + SELECT * FROM t1 +) +ORDER BY name +---- +3 Alice +3 John diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index cb907f6487e9..2199486a13ee 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -22,7 +22,7 @@ use crate::expr::{ Like, Sort, TryCast, WindowFunction, }; use crate::logical_plan::{Aggregate, Projection}; -use crate::utils::{from_plan, grouping_set_to_exprlist}; +use crate::utils::grouping_set_to_exprlist; use crate::{Expr, ExprSchemable, LogicalPlan}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; @@ -525,29 +525,56 @@ pub fn coerce_plan_expr_for_schema( plan: &LogicalPlan, schema: &DFSchema, ) -> Result { - let new_expr = plan - .expressions() + match plan { + // special case Projection to avoid adding multiple projections + LogicalPlan::Projection(Projection { expr, input, .. }) => { + let new_exprs = + coerce_exprs_for_schema(expr.clone(), input.schema(), schema)?; + let projection = Projection::try_new(new_exprs, input.clone())?; + Ok(LogicalPlan::Projection(projection)) + } + _ => { + let exprs: Vec = plan + .schema() + .fields() + .iter() + .map(|field| Expr::Column(field.qualified_column())) + .collect(); + + let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?; + let add_project = new_exprs.iter().any(|expr| expr.try_into_col().is_err()); + if add_project { + let projection = Projection::try_new(new_exprs, Arc::new(plan.clone()))?; + Ok(LogicalPlan::Projection(projection)) + } else { + Ok(plan.clone()) + } + } + } +} + +fn coerce_exprs_for_schema( + exprs: Vec, + src_schema: &DFSchema, + dst_schema: &DFSchema, +) -> Result> { + exprs .into_iter() .enumerate() - .map(|(i, expr)| { - let new_type = schema.field(i).data_type(); - if plan.schema().field(i).data_type() != schema.field(i).data_type() { - match (plan, &expr) { - ( - LogicalPlan::Projection(Projection { input, .. }), - Expr::Alias(e, alias), - ) => Ok(e.clone().cast_to(new_type, input.schema())?.alias(alias)), - _ => expr.cast_to(new_type, plan.schema()), + .map(|(idx, expr)| { + let new_type = dst_schema.field(idx).data_type(); + if new_type != &expr.get_type(src_schema)? { + match expr { + Expr::Alias(e, alias) => { + Ok(e.cast_to(new_type, src_schema)?.alias(alias)) + } + _ => expr.cast_to(new_type, src_schema), } } else { - Ok(expr) + Ok(expr.clone()) } }) - .collect::>>()?; - - let new_inputs = plan.inputs().into_iter().cloned().collect::>(); - - from_plan(plan, &new_expr, &new_inputs) + .collect::>() } #[cfg(test)] diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index d5e5257ea79b..95d808121ae3 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1016,7 +1016,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { Ok(Arc::new(project_with_column_index( - expr.to_vec(), + expr, input, Arc::new(union_schema.clone()), )?))