From 2521043ddcb3895a2010b8e328f3fa10f77fc094 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 8 Aug 2024 12:27:25 +0100 Subject: [PATCH] support `ANY()` op (#11849) * support ANY() op * use ExprPlanner * revert test changes * add planner tests * minimise diff * fix tests :fingers_crossed: * move error test to slt --- datafusion/expr/src/planner.rs | 7 +++ datafusion/functions-nested/src/array_has.rs | 22 +++++----- datafusion/functions-nested/src/planner.rs | 20 +++++++-- datafusion/sql/src/expr/mod.rs | 46 ++++++++++++++++---- datafusion/sqllogictest/test_files/array.slt | 19 ++++++++ 5 files changed, 92 insertions(+), 22 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index c775427df138..24f589c41582 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -197,6 +197,13 @@ pub trait ExprPlanner: Send + Sync { "Default planner compound identifier hasn't been implemented for ExprPlanner" ) } + + /// Plans `ANY` expression, e.g., `expr = ANY(array_expr)` + /// + /// Returns origin binary expression if not possible + fn plan_any(&self, expr: RawBinaryExpr) -> Result> { + Ok(PlannerResult::Original(expr)) + } } /// An operator with two arguments to plan diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index bdda5a565947..fe1df2579932 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -34,19 +34,19 @@ use std::sync::Arc; // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!(ArrayHas, array_has, - first_array second_array, // arg name + haystack_array element, // arg names "returns true, if the element appears in the first array, otherwise false.", // doc array_has_udf // internal function name ); make_udf_expr_and_func!(ArrayHasAll, array_has_all, - first_array second_array, // arg name + haystack_array needle_array, // arg names "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc array_has_all_udf // internal function name ); make_udf_expr_and_func!(ArrayHasAny, array_has_any, - first_array second_array, // arg name + haystack_array needle_array, // arg names "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc array_has_any_udf // internal function name ); @@ -262,26 +262,26 @@ enum ComparisonType { } fn general_array_has_dispatch( - array: &ArrayRef, - sub_array: &ArrayRef, + haystack: &ArrayRef, + needle: &ArrayRef, comparison_type: ComparisonType, ) -> Result { let array = if comparison_type == ComparisonType::Single { - let arr = as_generic_list_array::(array)?; - check_datatypes("array_has", &[arr.values(), sub_array])?; + let arr = as_generic_list_array::(haystack)?; + check_datatypes("array_has", &[arr.values(), needle])?; arr } else { - check_datatypes("array_has", &[array, sub_array])?; - as_generic_list_array::(array)? + check_datatypes("array_has", &[haystack, needle])?; + as_generic_list_array::(haystack)? }; let mut boolean_builder = BooleanArray::builder(array.len()); let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - let element = Arc::clone(sub_array); + let element = Arc::clone(needle); let sub_array = if comparison_type != ComparisonType::Single { - as_generic_list_array::(sub_array)? + as_generic_list_array::(needle)? } else { array }; diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index f980362105a1..4cd8faa3ca98 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -17,7 +17,7 @@ //! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] -use datafusion_common::{exec_err, utils::list_ndims, DFSchema, Result}; +use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, @@ -28,7 +28,7 @@ use datafusion_functions_aggregate::nth_value::nth_value_udaf; use crate::map::map_udf; use crate::{ - array_has::array_has_all, + array_has::{array_has_all, array_has_udf}, expr_fn::{array_append, array_concat, array_prepend}, extract::{array_element, array_slice}, make_array::make_array, @@ -102,7 +102,7 @@ impl ExprPlanner for NestedFunctionPlanner { fn plan_make_map(&self, args: Vec) -> Result>> { if args.len() % 2 != 0 { - return exec_err!("make_map requires an even number of arguments"); + return plan_err!("make_map requires an even number of arguments"); } let (keys, values): (Vec<_>, Vec<_>) = @@ -114,6 +114,20 @@ impl ExprPlanner for NestedFunctionPlanner { ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } + + fn plan_any(&self, expr: RawBinaryExpr) -> Result> { + if expr.op == sqlparser::ast::BinaryOperator::Eq { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + array_has_udf(), + // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` + vec![expr.right, expr.left], + ), + ))) + } else { + plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op) + } + } } pub struct FieldAccessPlanner; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b80ffb6aed3f..edb0002842a8 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,12 +17,12 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; -use datafusion_expr::planner::PlannerResult; -use datafusion_expr::planner::RawDictionaryExpr; -use datafusion_expr::planner::RawFieldAccessExpr; +use datafusion_expr::planner::{ + PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr, +}; use sqlparser::ast::{ - CastKind, DictionaryField, Expr as SQLExpr, MapEntry, StructField, Subscript, - TrimWhereField, Value, + BinaryOperator, CastKind, DictionaryField, Expr as SQLExpr, MapEntry, StructField, + Subscript, TrimWhereField, Value, }; use datafusion_common::{ @@ -104,13 +104,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn build_logical_expr( &self, - op: sqlparser::ast::BinaryOperator, + op: BinaryOperator, left: Expr, right: Expr, schema: &DFSchema, ) -> Result { // try extension planers - let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right }; + let mut binary_expr = RawBinaryExpr { op, left, right }; for planner in self.context_provider.get_expr_planners() { match planner.plan_binary_op(binary_expr, schema)? { PlannerResult::Planned(expr) => { @@ -122,7 +122,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let datafusion_expr::planner::RawBinaryExpr { op, left, right } = binary_expr; + let RawBinaryExpr { op, left, right } = binary_expr; Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), self.parse_sql_binary_op(op)?, @@ -631,6 +631,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Map(map) => { self.try_plan_map_literal(map.entries, schema, planner_context) } + SQLExpr::AnyOp { + left, + compare_op, + right, + } => { + let mut binary_expr = RawBinaryExpr { + op: compare_op, + left: self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?, + right: self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?, + }; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_any(binary_expr)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(expr) => { + binary_expr = expr; + } + } + } + not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") + } _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index f2972e4c14c2..b71bc765ba37 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5351,6 +5351,25 @@ true false true false false false true true false false true false true #---- #true false true false false false true true false false true false true +# any operator +query ? +select column3 from arrays where 'L'=any(column3); +---- +[L, o, r, e, m] + +query I +select count(*) from arrays where 'L'=any(column3); +---- +1 + +query I +select count(*) from arrays where 'X'=any(column3); +---- +0 + +query error DataFusion error: Error during planning: Unsupported AnyOp: '>', only '=' is supported +select count(*) from arrays where 'X'>any(column3); + ## array_distinct #TODO: https://github.com/apache/datafusion/issues/7142