diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 78f65c5b82ab..70cb54c24f6c 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -21,9 +21,10 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite}; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_expr::{Expr, LogicalPlan, Subquery}; use std::sync::Arc; /// Analyzer rule that invokes [`FunctionRewrite`]s on expressions @@ -45,54 +46,66 @@ impl AnalyzerRule for ApplyFunctionRewrites { } fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result { - self.analyze_internal(&plan, options) + analyze_internal(&plan, &self.function_rewrites, options) } } -impl ApplyFunctionRewrites { - fn analyze_internal( - &self, - plan: &LogicalPlan, - options: &ConfigOptions, - ) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| self.analyze_internal(p, options)) - .collect::>>()?; - - // get schema representing all available input fields. This is used for data type - // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); - - if let LogicalPlan::TableScan(ts) = plan { - let source_schema = DFSchema::try_from_qualified_schema( - ts.table_name.clone(), - &ts.source.schema(), - )?; - schema.merge(&source_schema); - } +fn analyze_internal( + plan: &LogicalPlan, + function_rewrites: &[Arc], + options: &ConfigOptions, +) -> Result { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| analyze_internal(p, function_rewrites, options)) + .collect::>>()?; - let mut expr_rewrite = OperatorToFunctionRewriter { - function_rewrites: &self.function_rewrites, - options, - schema: &schema, - }; + // get schema representing all available input fields. This is used for data type + // resolution only, so order does not matter here + let mut schema = merge_schema(new_inputs.iter().collect()); - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure names don't change: - // https://github.com/apache/arrow-datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; - - plan.with_new_exprs(new_expr, new_inputs) + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = DFSchema::try_from_qualified_schema( + ts.table_name.clone(), + &ts.source.schema(), + )?; + schema.merge(&source_schema); } + + let mut expr_rewrite = OperatorToFunctionRewriter { + function_rewrites, + options, + schema: &schema, + }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| { + // ensure names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + rewrite_preserving_name(expr, &mut expr_rewrite) + }) + .collect::>>()?; + + plan.with_new_exprs(new_expr, new_inputs) } + +fn rewrite_subquery( + mut subquery: Subquery, + function_rewrites: &[Arc], + options: &ConfigOptions, +) -> Result { + subquery.subquery = Arc::new(analyze_internal( + &subquery.subquery, + function_rewrites, + options, + )?); + Ok(subquery) +} + struct OperatorToFunctionRewriter<'a> { function_rewrites: &'a [Arc], options: &'a ConfigOptions, @@ -113,6 +126,40 @@ impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> { expr = result.data } + // recurse into subqueries if needed + let expr = match expr { + Expr::ScalarSubquery(subquery) => Expr::ScalarSubquery(rewrite_subquery( + subquery, + self.function_rewrites, + self.options, + )?), + + Expr::Exists(Exists { subquery, negated }) => Expr::Exists(Exists { + subquery: rewrite_subquery( + subquery, + self.function_rewrites, + self.options, + )?, + negated, + }), + + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => Expr::InSubquery(InSubquery { + expr, + subquery: rewrite_subquery( + subquery, + self.function_rewrites, + self.options, + )?, + negated, + }), + + expr => expr, + }; + Ok(if transformed { Transformed::yes(expr) } else { diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index cc6428e51435..1ae89c9159f8 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1060,3 +1060,58 @@ logical_plan Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) --Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a ----TableScan: t projection=[a] + +### +## Ensure that operators are rewritten in subqueries +### + +statement ok +create table foo(x int) as values (1); + +# Show input data +query ? +select struct(1, 'b') +---- +{c0: 1, c1: b} + + +query T +select (select struct(1, 'b')['c1']); +---- +b + +query T +select 'foo' || (select struct(1, 'b')['c1']); +---- +foob + +query I +SELECT * FROM (VALUES (1), (2)) +WHERE column1 IN (SELECT struct(1, 'b')['c0']); +---- +1 + +# also add an expression so the subquery is the output expr +query I +SELECT * FROM (VALUES (1), (2)) +WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']); +---- +1 + + +query I +SELECT * FROM foo +WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1); +---- +1 + +# also add an expression so the subquery is the output expr +query I +SELECT * FROM foo +WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1); +---- +1 + + +statement ok +drop table foo;