Skip to content

Commit

Permalink
Replace placeholders in ScalarSubqueries (#5216)
Browse files (browse the repository at this point in the history)
* Failing subquery test

* Fix test

* fmt
  • Loading branch information
avantgardnerio authored Feb 8, 2023
1 parent f0c6719 commit b4cf60a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 24 deletions.
57 changes: 33 additions & 24 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use crate::logical_plan::builder::validate_unique_names;
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::plan;
use crate::utils::{
self, exprlist_to_fields, from_plan, grouping_set_expr_count,
grouping_set_to_exprlist,
Expand Down Expand Up @@ -710,31 +711,39 @@ impl LogicalPlan {
param_values: &[ScalarValue],
) -> Result<Expr, DataFusionError> {
rewrite_expr(expr, |expr| {
if let Expr::Placeholder { id, data_type } = &expr {
// convert id (in format $1, $2, ..) to idx (0, 1, ..)
let idx = id[1..].parse::<usize>().map_err(|e| {
DataFusionError::Internal(format!(
"Failed to parse placeholder id: {e}"
))
})? - 1;
// value at the idx-th position in param_values should be the value for the placeholder
let value = param_values.get(idx).ok_or_else(|| {
DataFusionError::Internal(format!(
"No value found for placeholder with id {id}"
))
})?;
// check if the data type of the value matches the data type of the placeholder
if Some(value.get_datatype()) != *data_type {
return Err(DataFusionError::Internal(format!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.get_datatype()
)));
match &expr {
Expr::Placeholder { id, data_type } => {
// convert id (in format $1, $2, ..) to idx (0, 1, ..)
let idx = id[1..].parse::<usize>().map_err(|e| {
DataFusionError::Internal(format!(
"Failed to parse placeholder id: {e}"
))
})? - 1;
// value at the idx-th position in param_values should be the value for the placeholder
let value = param_values.get(idx).ok_or_else(|| {
DataFusionError::Internal(format!(
"No value found for placeholder with id {id}"
))
})?;
// check if the data type of the value matches the data type of the placeholder
if Some(value.get_datatype()) != *data_type {
return Err(DataFusionError::Internal(format!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.get_datatype()
)));
}
// Replace the placeholder with the value
Ok(Expr::Literal(value.clone()))
}
// Replace the placeholder with the value
Ok(Expr::Literal(value.clone()))
} else {
Ok(expr)
Expr::ScalarSubquery(qry) => {
let subquery = Arc::new(
qry.subquery
.replace_params_with_values(&param_values.to_vec())?,
);
Ok(Expr::ScalarSubquery(plan::Subquery { subquery }))
}
_ => Ok(expr),
}
})
}
Expand Down
41 changes: 41 additions & 0 deletions datafusion/sql/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3527,6 +3527,47 @@ Projection: person.id, person.age
prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}

/// Checks that a `$1` placeholder nested inside a scalar subquery is reported by
/// `get_parameter_types` and is substituted by `replace_params_with_values`.
#[test]
fn test_prepare_statement_infer_types_subquery() {
// The only placeholder lives in the WHERE clause of the scalar subquery, not in the outer query.
let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)";

// Expected plan text before substitution: the subquery's filter still shows `$1`.
// `.trim()` strips the newlines that surround the raw-string body.
let expected_plan = r#"
Projection: person.id, person.age
Filter: person.age = (<subquery>)
Subquery:
Projection: MAX(person.age)
Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]
Filter: person.id = $1
TableScan: person
TableScan: person
"#
.trim();

// NOTE(review): this string says Int32 while the parameter type asserted below is UInt32;
// how `prepare_stmt_quick_test` uses `expected_dt` is not visible here — confirm it is intended.
let expected_dt = "[Int32]";
let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);

// The plan must report exactly one parameter, `$1`, typed as UInt32
// (presumably because it is compared with `person.id` — verify against the test schema).
let actual_types = plan.get_parameter_types().unwrap();
let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]);
assert_eq!(actual_types, expected_types);

// replace params with values
// After substitution the subquery's filter should show the literal `UInt32(10)` instead of `$1`.
let param_values = vec![ScalarValue::UInt32(Some(10))];
let expected_plan = r#"
Projection: person.id, person.age
Filter: person.age = (<subquery>)
Subquery:
Projection: MAX(person.age)
Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]
Filter: person.id = UInt32(10)
TableScan: person
TableScan: person
"#
.trim();
// Shadow `plan` with the substituted plan; `unwrap` fails the test if substitution returns an error.
let plan = plan.replace_params_with_values(&param_values).unwrap();

// NOTE(review): the helper receives the already-substituted plan together with the same
// values; what it checks is defined elsewhere — presumably it compares against `expected_plan`.
prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}

#[test]
fn test_prepare_statement_update_infer() {
let sql = "update person set age=$1 where id=$2";
Expand Down

0 comments on commit b4cf60a

Please sign in to comment.