Skip to content

Commit

Permalink
Rewrite coerce_plan_expr_for_schema to fix union type coercion (#4862)
Browse files Browse the repository at this point in the history
* Rewrite coerce_plan_expr_for_schema

* add integration tests

* Update datafusion/expr/src/expr_rewriter.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* fix comment

* fix tests

* fix tests

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
ygf11 and alamb authored Jan 12, 2023
1 parent 3b86643 commit eb19a67
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 19 deletions.
19 changes: 19 additions & 0 deletions datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,25 @@ fn create_sort_merge_join_datatype_context() -> Result<SessionContext> {
Ok(ctx)
}

fn create_union_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
let t1_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("name", DataType::UInt8, true),
]));
let t1_data = RecordBatch::new_empty(t1_schema);
ctx.register_batch("t1", t1_data)?;

let t2_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt8, true),
Field::new("name", DataType::UInt8, true),
]));
let t2_data = RecordBatch::new_empty(t2_schema);
ctx.register_batch("t2", t2_data)?;

Ok(ctx)
}

fn get_tpch_table_schema(table: &str) -> Schema {
match table {
"customer" => Schema::new(vec![
Expand Down
82 changes: 82 additions & 0 deletions datafusion/core/tests/sql/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,85 @@ async fn union_schemas() -> Result<()> {
assert_batches_eq!(expected, &result);
Ok(())
}

#[tokio::test]
async fn union_with_except_input() -> Result<()> {
let ctx = create_union_context()?;
let sql = "(
SELECT name FROM t1
EXCEPT
SELECT name FROM t2
)
UNION ALL
(
SELECT name FROM t2
EXCEPT
SELECT name FROM t1
)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Union [name:UInt8;N]",
" LeftAnti Join: t1.name = t2.name [name:UInt8;N]",
" Distinct: [name:UInt8;N]",
" TableScan: t1 projection=[name] [name:UInt8;N]",
" Projection: t2.name [name:UInt8;N]",
" TableScan: t2 projection=[name] [name:UInt8;N]",
" LeftAnti Join: t2.name = t1.name [name:UInt8;N]",
" Distinct: [name:UInt8;N]",
" TableScan: t2 projection=[name] [name:UInt8;N]",
" Projection: t1.name [name:UInt8;N]",
" TableScan: t1 projection=[name] [name:UInt8;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
Ok(())
}

#[tokio::test]
async fn union_with_type_coercion() -> Result<()> {
let ctx = create_union_context()?;
let sql = "(
SELECT id, name FROM t1
EXCEPT
SELECT id, name FROM t2
)
UNION ALL
(
SELECT id, name FROM t2
EXCEPT
SELECT id, name FROM t1
)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Union [id:Int32;N, name:UInt8;N]",
" LeftAnti Join: t1.id = CAST(t2.id AS Int32), t1.name = t2.name [id:Int32;N, name:UInt8;N]",
" Distinct: [id:Int32;N, name:UInt8;N]",
" TableScan: t1 projection=[id, name] [id:Int32;N, name:UInt8;N]",
" Projection: t2.id, t2.name [id:UInt8;N, name:UInt8;N]",
" TableScan: t2 projection=[id, name] [id:UInt8;N, name:UInt8;N]",
" Projection: CAST(t2.id AS Int32) AS id, t2.name [id:Int32;N, name:UInt8;N]",
" LeftAnti Join: CAST(t2.id AS Int32) = t1.id, t2.name = t1.name [id:UInt8;N, name:UInt8;N]",
" Distinct: [id:UInt8;N, name:UInt8;N]",
" TableScan: t2 projection=[id, name] [id:UInt8;N, name:UInt8;N]",
" Projection: t1.id, t1.name [id:Int32;N, name:UInt8;N]",
" TableScan: t1 projection=[id, name] [id:Int32;N, name:UInt8;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
Ok(())
}
78 changes: 78 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/union.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

##########
## UNION Tests
##########

statement ok
CREATE TABLE t1(
id INT,
name TEXT,
) as VALUES
(1, 'Alex'),
(2, 'Bob'),
(3, 'Alice')
;

statement ok
CREATE TABLE t2(
id TINYINT,
name TEXT,
) as VALUES
(1, 'Alex'),
(2, 'Bob'),
(3, 'John')
;

# union with EXCEPT(JOIN)
query T
(
SELECT name FROM t1
EXCEPT
SELECT name FROM t2
)
UNION ALL
(
SELECT name FROM t2
EXCEPT
SELECT name FROM t1
)
ORDER BY name
----
Alice
John



# union with type coercion
query T
(
SELECT * FROM t1
EXCEPT
SELECT * FROM t2
)
UNION ALL
(
SELECT * FROM t2
EXCEPT
SELECT * FROM t1
)
ORDER BY name
----
3 Alice
3 John
63 changes: 45 additions & 18 deletions datafusion/expr/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::expr::{
Like, Sort, TryCast, WindowFunction,
};
use crate::logical_plan::{Aggregate, Projection};
use crate::utils::{from_plan, grouping_set_to_exprlist};
use crate::utils::grouping_set_to_exprlist;
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
Expand Down Expand Up @@ -525,29 +525,56 @@ pub fn coerce_plan_expr_for_schema(
plan: &LogicalPlan,
schema: &DFSchema,
) -> Result<LogicalPlan> {
let new_expr = plan
.expressions()
match plan {
// special case Projection to avoid adding multiple projections
LogicalPlan::Projection(Projection { expr, input, .. }) => {
let new_exprs =
coerce_exprs_for_schema(expr.clone(), input.schema(), schema)?;
let projection = Projection::try_new(new_exprs, input.clone())?;
Ok(LogicalPlan::Projection(projection))
}
_ => {
let exprs: Vec<Expr> = plan
.schema()
.fields()
.iter()
.map(|field| Expr::Column(field.qualified_column()))
.collect();

let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
let add_project = new_exprs.iter().any(|expr| expr.try_into_col().is_err());
if add_project {
let projection = Projection::try_new(new_exprs, Arc::new(plan.clone()))?;
Ok(LogicalPlan::Projection(projection))
} else {
Ok(plan.clone())
}
}
}
}

fn coerce_exprs_for_schema(
exprs: Vec<Expr>,
src_schema: &DFSchema,
dst_schema: &DFSchema,
) -> Result<Vec<Expr>> {
exprs
.into_iter()
.enumerate()
.map(|(i, expr)| {
let new_type = schema.field(i).data_type();
if plan.schema().field(i).data_type() != schema.field(i).data_type() {
match (plan, &expr) {
(
LogicalPlan::Projection(Projection { input, .. }),
Expr::Alias(e, alias),
) => Ok(e.clone().cast_to(new_type, input.schema())?.alias(alias)),
_ => expr.cast_to(new_type, plan.schema()),
.map(|(idx, expr)| {
let new_type = dst_schema.field(idx).data_type();
if new_type != &expr.get_type(src_schema)? {
match expr {
Expr::Alias(e, alias) => {
Ok(e.cast_to(new_type, src_schema)?.alias(alias))
}
_ => expr.cast_to(new_type, src_schema),
}
} else {
Ok(expr)
Ok(expr.clone())
}
})
.collect::<Result<Vec<_>>>()?;

let new_inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();

from_plan(plan, &new_expr, &new_inputs)
.collect::<Result<_>>()
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalP
match plan {
LogicalPlan::Projection(Projection { expr, input, .. }) => {
Ok(Arc::new(project_with_column_index(
expr.to_vec(),
expr,
input,
Arc::new(union_schema.clone()),
)?))
Expand Down

0 comments on commit eb19a67

Please sign in to comment.