Skip to content

Commit

Permalink
fix: generate logical plan for UPDATE SET FROM statement (apache#7984)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonahgao committed Oct 31, 2023
1 parent 747cb50 commit 656c6a9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 32 deletions.
68 changes: 36 additions & 32 deletions datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -978,10 +978,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// Do a table lookup to verify the table exists
let table_name = self.object_name_to_table_reference(table_name)?;
let table_source = self.context_provider.get_table_source(table_name.clone())?;
let arrow_schema = (*table_source.schema()).clone();
let table_schema = Arc::new(DFSchema::try_from_qualified_schema(
table_name.clone(),
&arrow_schema,
&table_source.schema(),
)?);

// Overwrite with assignment expressions
Expand All @@ -1000,55 +999,60 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
})
.collect::<Result<HashMap<String, Expr>>>()?;

let values_and_types = table_schema
.fields()
.iter()
.map(|f| {
let col_name = f.name();
let val = assign_map.remove(col_name).unwrap_or_else(|| {
ast::Expr::Identifier(ast::Ident::from(col_name.as_str()))
});
(col_name, val, f.data_type())
})
.collect::<Vec<_>>();

// Build scan
let from = from.unwrap_or(table);
let scan = self.plan_from_tables(vec![from], &mut planner_context)?;
// Build scan, join with from table if it exists.
let mut input_tables = vec![table];
input_tables.extend(from);
let scan = self.plan_from_tables(input_tables, &mut planner_context)?;

// Filter
let source = match predicate_expr {
None => scan,
Some(predicate_expr) => {
let filter_expr = self.sql_to_expr(
predicate_expr,
&table_schema,
scan.schema(),
&mut planner_context,
)?;
let mut using_columns = HashSet::new();
expr_to_columns(&filter_expr, &mut using_columns)?;
let filter_expr = normalize_col_with_schemas_and_ambiguity_check(
filter_expr,
&[&[&table_schema]],
&[&[&scan.schema()]],
&[using_columns],
)?;
LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?)
}
};

// Projection
let mut exprs = vec![];
for (col_name, expr, dt) in values_and_types.into_iter() {
let mut expr = self.sql_to_expr(expr, &table_schema, &mut planner_context)?;
// Update placeholder's datatype to the type of the target column
if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr {
placeholder.data_type =
placeholder.data_type.take().or_else(|| Some(dt.clone()));
}
// Cast to target column type, if necessary
let expr = expr.cast_to(dt, source.schema())?.alias(col_name);
exprs.push(expr);
}
// Build updated values for each column, using the previous value if not modified
let exprs = table_schema
.fields()
.iter()
.map(|field| {
let expr = match assign_map.remove(field.name()) {
Some(new_value) => {
let mut expr = self.sql_to_expr(
new_value,
source.schema(),
&mut planner_context,
)?;
// Update placeholder's datatype to the type of the target column
if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr
{
placeholder.data_type = placeholder
.data_type
.take()
.or_else(|| Some(field.data_type().clone()));
}
// Cast to target column type, if necessary
expr.cast_to(field.data_type(), source.schema())?
}
None => datafusion_expr::Expr::Column(field.qualified_column()),
};
Ok(expr.alias(field.name()))
})
.collect::<Result<Vec<_>>>()?;

let source = project(source, exprs)?;

let plan = LogicalPlan::Dml(DmlStatement {
Expand Down
36 changes: 36 additions & 0 deletions datafusion/sqllogictest/test_files/update.slt
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,39 @@ logical_plan
Dml: op=[Update] table=[t1]
--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d
----TableScan: t1

statement ok
create table t2(a int, b varchar, c double, d int);

## set from subquery
query TT
explain update t1 set b = (select max(b) from t2 where t1.a = t2.a)
----
logical_plan
Dml: op=[Update] table=[t1]
--Projection: t1.a AS a, (<subquery>) AS b, t1.c AS c, t1.d AS d
----Subquery:
------Projection: MAX(t2.b)
--------Aggregate: groupBy=[[]], aggr=[[MAX(t2.b)]]
----------Filter: outer_ref(t1.a) = t2.a
------------TableScan: t2
----TableScan: t1

# set from other table
query TT
explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0;
----
logical_plan
Dml: op=[Update] table=[t1]
--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d
----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1)
------CrossJoin:
--------TableScan: t1
--------TableScan: t2

statement ok
create table t3(a int, b varchar, c double, d int);

# set from mutiple tables, sqlparser only supports from one table
query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\)
explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a;

0 comments on commit 656c6a9

Please sign in to comment.