diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9ec0148d9122..89a50e5749bf 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -60,7 +60,7 @@ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, - Expr, UserDefinedLogicalNode, WindowUDF, + Expr, ParseCustomOperator, UserDefinedLogicalNode, WindowUDF, }; // backwards compatibility @@ -1390,6 +1390,19 @@ impl FunctionRegistry for SessionContext { ) -> Result<()> { self.state.write().register_function_rewrite(rewrite) } + + fn register_parse_custom_operator( + &mut self, + parse_custom_operator: Arc, + ) -> Result<()> { + self.state + .write() + .register_parse_custom_operator(parse_custom_operator) + } + + fn parse_custom_operators(&self) -> Vec> { + self.state.read().parse_custom_operators() + } } /// Create a new task context instance from SessionContext diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d2bac134b54a..f500076193a6 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -56,8 +56,8 @@ use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::var_provider::{is_system_variables, VarType}; use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource, - WindowUDF, + AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ParseCustomOperator, + ScalarUDF, TableSource, WindowUDF, }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -91,6 +91,8 @@ pub struct SessionState { session_id: String, /// Responsible for analyzing and rewrite a logical plan before optimization analyzer: Analyzer, + /// Provides support for parsing custom SQL operators, e.g. `->>` or `?` + parse_custom_operators: Vec>, /// Responsible for optimizing a logical plan optimizer: Optimizer, /// Responsible for optimizing a physical execution plan @@ -221,6 +223,7 @@ impl SessionState { let mut new_self = SessionState { session_id, analyzer: Analyzer::new(), + parse_custom_operators: vec![], optimizer: Optimizer::new(), physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), @@ -543,7 +546,7 @@ impl SessionState { /// Convert an AST Statement into a LogicalPlan pub async fn statement_to_plan( &self, - statement: datafusion_sql::parser::Statement, + statement: Statement, ) -> datafusion_common::Result { let references = self.resolve_table_references(&statement)?; @@ -573,6 +576,7 @@ impl SessionState { ParserOptions { parse_float_as_decimal: sql_parser_options.parse_float_as_decimal, enable_ident_normalization: sql_parser_options.enable_ident_normalization, + parse_custom_operator: self.parse_custom_operators(), } } @@ -1074,6 +1078,18 @@ impl FunctionRegistry for SessionState { self.analyzer.add_function_rewrite(rewrite); Ok(()) } + + fn register_parse_custom_operator( + &mut self, + parse_custom_operator: Arc, + ) -> datafusion_common::Result<()> { + self.parse_custom_operators.push(parse_custom_operator); + Ok(()) + } + + fn parse_custom_operators(&self) -> Vec> { + self.parse_custom_operators.clone() + } } impl OptimizerConfig for SessionState { diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 6c6d966cc3aa..a893ca02b322 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +/// Tests for custom operator parsing and substitution +mod user_defined_custom_operators; /// Tests for user defined Scalar functions mod user_defined_scalar_functions; diff --git a/datafusion/core/tests/user_defined/user_defined_custom_operators.rs b/datafusion/core/tests/user_defined/user_defined_custom_operators.rs new file mode 100644 index 000000000000..38ab7f615a9c --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_custom_operators.rs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::RecordBatch; +use std::sync::Arc; + +use datafusion::arrow::datatypes::DataType; +use datafusion::common::config::ConfigOptions; +use datafusion::common::tree_node::Transformed; +use datafusion::common::{assert_batches_eq, DFSchema}; +use datafusion::error::Result; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::expr_rewriter::FunctionRewrite; +use datafusion::logical_expr::{ + CustomOperator, Operator, ParseCustomOperator, WrapCustomOperator, +}; +use datafusion::prelude::*; +use datafusion::sql::sqlparser::ast::BinaryOperator; + +#[derive(Debug)] +enum MyCustomOperator { + Arrow, + LongArrow, +} + +impl std::fmt::Display for MyCustomOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MyCustomOperator::Arrow => write!(f, "->"), + MyCustomOperator::LongArrow => write!(f, "->>"), + } + } +} + +impl CustomOperator for MyCustomOperator { + fn binary_signature( + &self, + lhs: &DataType, + rhs: &DataType, + ) -> Result<(DataType, DataType, DataType)> { + Ok((lhs.clone(), rhs.clone(), lhs.clone())) + } + + fn op_to_sql(&self) -> Result { + match self { + MyCustomOperator::Arrow => Ok(BinaryOperator::Arrow), + MyCustomOperator::LongArrow => Ok(BinaryOperator::LongArrow), + } + } + + fn name(&self) -> &'static str { + match self { + MyCustomOperator::Arrow => "Arrow", + MyCustomOperator::LongArrow => "LongArrow", + } + } +} + +impl TryFrom<&str> for MyCustomOperator { + type Error = (); + + fn try_from(value: &str) -> std::result::Result { + match value { + "Arrow" => Ok(MyCustomOperator::Arrow), + "LongArrow" => Ok(MyCustomOperator::LongArrow), + _ => Err(()), + } + } +} + +#[derive(Debug)] +struct CustomOperatorParser; + +impl ParseCustomOperator for CustomOperatorParser { + fn name(&self) -> &str { + "CustomOperatorParser" + } + + fn op_from_ast(&self, op: &BinaryOperator) -> Result> { + match op { + BinaryOperator::Arrow => Ok(Some(MyCustomOperator::Arrow.into())), + BinaryOperator::LongArrow => Ok(Some(MyCustomOperator::LongArrow.into())), + _ => Ok(None), + } + } + + fn op_from_name(&self, raw_op: &str) -> Result> { + if let Ok(op) = MyCustomOperator::try_from(raw_op) { + Ok(Some(op.into())) + } else { + Ok(None) + } + } +} + +impl FunctionRewrite for CustomOperatorParser { + fn name(&self) -> &str { + "CustomOperatorParser" + } + + fn rewrite( + &self, + expr: Expr, + _schema: &DFSchema, + _config: &ConfigOptions, + ) -> Result> { + if let Expr::BinaryExpr(bin_expr) = &expr { + if let Operator::Custom(WrapCustomOperator(op)) = &bin_expr.op { + if let Ok(pg_op) = MyCustomOperator::try_from(op.name()) { + // return BinaryExpr with a different operator + let mut bin_expr = bin_expr.clone(); + bin_expr.op = match pg_op { + MyCustomOperator::Arrow => Operator::StringConcat, + MyCustomOperator::LongArrow => Operator::Plus, + }; + return Ok(Transformed::yes(Expr::BinaryExpr(bin_expr))); + } + } + } + Ok(Transformed::no(expr)) + } +} + +async fn plan_and_collect(sql: &str) -> Result> { + let mut ctx = SessionContext::new(); + ctx.register_function_rewrite(Arc::new(CustomOperatorParser))?; + ctx.register_parse_custom_operator(Arc::new(CustomOperatorParser))?; + ctx.sql(sql).await?.collect().await +} + +#[tokio::test] +async fn test_custom_operators_arrow() { + let actual = plan_and_collect("select 'foo'->'bar';").await.unwrap(); + let expected = [ + "+----------------------------+", + "| Utf8(\"foo\") -> Utf8(\"bar\") |", + "+----------------------------+", + "| foobar |", + "+----------------------------+", + ]; + assert_batches_eq!(&expected, &actual); +} + +#[tokio::test] +async fn test_custom_operators_long_arrow() { + let actual = plan_and_collect("select 1->>2;").await.unwrap(); + let expected = [ + "+-----------------------+", + "| Int64(1) ->> Int64(2) |", + "+-----------------------+", + "| 3 |", + "+-----------------------+", + ]; + assert_batches_eq!(&expected, &actual); +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 89ee94f9f845..c434da82bcbe 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -48,6 +48,7 @@ pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod parse_custom_operator; pub mod registry; pub mod simplify; pub mod sort_properties; @@ -76,7 +77,8 @@ pub use function::{ pub use groups_accumulator::{EmitTo, GroupsAccumulator}; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::*; -pub use operator::Operator; +pub use operator::{CustomOperator, Operator, WrapCustomOperator}; +pub use parse_custom_operator::ParseCustomOperator; pub use partition_evaluator::PartitionEvaluator; pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index 742511822a0f..77e2f5fda666 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -21,8 +21,13 @@ use crate::expr_fn::binary_expr; use crate::Expr; use crate::Like; use std::fmt; +use std::hash::{Hash, Hasher}; use std::ops; use std::ops::Not; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use datafusion_common::Result; /// Operators applied to expressions #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -93,11 +98,13 @@ pub enum Operator { AtArrow, /// Arrow at, like `<@` ArrowAt, + /// Custom operator + Custom(WrapCustomOperator), } impl Operator { /// If the operator can be negated, return the negated operator - /// otherwise return None + /// otherwise return `None` pub fn negate(&self) -> Option { match self { Operator::Eq => Some(Operator::NotEq), @@ -131,51 +138,64 @@ impl Operator { | Operator::StringConcat | Operator::AtArrow | Operator::ArrowAt => None, + Operator::Custom(op) => op.0.negate(), } } /// Return true if the operator is a numerical operator. /// - /// For example, 'Binary(a, +, b)' would be a numerical expression. + /// For example, `Binary(a, +, b)` would be a numerical expression. /// PostgresSQL concept: pub fn is_numerical_operators(&self) -> bool { - matches!( - self, - Operator::Plus - | Operator::Minus - | Operator::Multiply - | Operator::Divide - | Operator::Modulo - ) + if let Self::Custom(op) = self { + op.0.is_numerical_operators() + } else { + matches!( + self, + Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + ) + } } /// Return true if the operator is a comparison operator. /// - /// For example, 'Binary(a, >, b)' would be a comparison expression. + /// For example, `Binary(a, >, b)` would be a comparison expression. pub fn is_comparison_operator(&self) -> bool { - matches!( - self, - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq - | Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom - | Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch - ) + if let Self::Custom(WrapCustomOperator(op)) = self { + op.is_comparison_operator() + } else { + matches!( + self, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + ) + } } /// Return true if the operator is a logic operator. /// - /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would + /// For example, `Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))` would /// be a logical expression. pub fn is_logic_operator(&self) -> bool { - matches!(self, Operator::And | Operator::Or) + if let Self::Custom(WrapCustomOperator(op)) = self { + op.is_logic_operator() + } else { + matches!(self, Operator::And | Operator::Or) + } } /// Return the operator where swapping lhs and rhs wouldn't change the result. @@ -214,6 +234,7 @@ impl Operator { | Operator::BitwiseShiftRight | Operator::BitwiseShiftLeft | Operator::StringConcat => None, + Operator::Custom(WrapCustomOperator(op)) => op.swap(), } } @@ -249,46 +270,143 @@ impl Operator { | Operator::StringConcat | Operator::AtArrow | Operator::ArrowAt => 0, + Operator::Custom(WrapCustomOperator(op)) => op.precedence(), } } } impl fmt::Display for Operator { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let display = match &self { - Operator::Eq => "=", - Operator::NotEq => "!=", - Operator::Lt => "<", - Operator::LtEq => "<=", - Operator::Gt => ">", - Operator::GtEq => ">=", - Operator::Plus => "+", - Operator::Minus => "-", - Operator::Multiply => "*", - Operator::Divide => "/", - Operator::Modulo => "%", - Operator::And => "AND", - Operator::Or => "OR", - Operator::RegexMatch => "~", - Operator::RegexIMatch => "~*", - Operator::RegexNotMatch => "!~", - Operator::RegexNotIMatch => "!~*", - Operator::LikeMatch => "~~", - Operator::ILikeMatch => "~~*", - Operator::NotLikeMatch => "!~~", - Operator::NotILikeMatch => "!~~*", - Operator::IsDistinctFrom => "IS DISTINCT FROM", - Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", - Operator::BitwiseAnd => "&", - Operator::BitwiseOr => "|", - Operator::BitwiseXor => "BIT_XOR", - Operator::BitwiseShiftRight => ">>", - Operator::BitwiseShiftLeft => "<<", - Operator::StringConcat => "||", - Operator::AtArrow => "@>", - Operator::ArrowAt => "<@", - }; - write!(f, "{display}") + match &self { + Operator::Eq => write!(f, "="), + Operator::NotEq => write!(f, "!="), + Operator::Lt => write!(f, "<"), + Operator::LtEq => write!(f, "<="), + Operator::Gt => write!(f, ">"), + Operator::GtEq => write!(f, ">="), + Operator::Plus => write!(f, "+"), + Operator::Minus => write!(f, "-"), + Operator::Multiply => write!(f, "*"), + Operator::Divide => write!(f, "/"), + Operator::Modulo => write!(f, "%"), + Operator::And => write!(f, "AND"), + Operator::Or => write!(f, "OR"), + Operator::RegexMatch => write!(f, "~"), + Operator::RegexIMatch => write!(f, "~*"), + Operator::RegexNotMatch => write!(f, "!~"), + Operator::RegexNotIMatch => write!(f, "!~*"), + Operator::LikeMatch => write!(f, "~~"), + Operator::ILikeMatch => write!(f, "~~*"), + Operator::NotLikeMatch => write!(f, "!~~"), + Operator::NotILikeMatch => write!(f, "!~~*"), + Operator::IsDistinctFrom => write!(f, "IS DISTINCT FROM"), + Operator::IsNotDistinctFrom => write!(f, "IS NOT DISTINCT FROM"), + Operator::BitwiseAnd => write!(f, "&"), + Operator::BitwiseOr => write!(f, "|"), + Operator::BitwiseXor => write!(f, "BIT_XOR"), + Operator::BitwiseShiftRight => write!(f, ">>"), + Operator::BitwiseShiftLeft => write!(f, "<<"), + Operator::StringConcat => write!(f, "||"), + Operator::AtArrow => write!(f, "@>"), + Operator::ArrowAt => write!(f, "<@"), + Operator::Custom(WrapCustomOperator(op)) => write!(f, "{op}"), + } + } +} + +impl From for Operator { + fn from(op: T) -> Self { + Operator::Custom(WrapCustomOperator(Arc::new(op))) + } +} + +pub trait CustomOperator: fmt::Debug + fmt::Display + Send + Sync { + /// Use in `datafusion/expr/src/type_coercion/binary.rs::Signature`, but the struct there isn't public, + /// hence returning a tuple. + /// + /// Returns `(lhs_type, rhs_type, return_type)` + fn binary_signature( + &self, + lhs: &DataType, + rhs: &DataType, + ) -> Result<(DataType, DataType, DataType)>; + + /// Used by unparse to convert the operator back to SQL + fn op_to_sql(&self) -> Result; + + /// Name used to uniquely identify the operator, and in logical plan producer + fn name(&self) -> &'static str; + + /// If the operator can be negated, return the negated operator + /// otherwise return None + fn negate(&self) -> Option { + None + } + + /// Return true if the operator is a numerical operator. + /// + /// For example, `Binary(a, +, b)` would be a numerical expression. + /// PostgresSQL concept: + fn is_numerical_operators(&self) -> bool { + false + } + + /// Return true if the operator is a comparison operator. + /// + /// For example, `Binary(a, >, b)` would be a comparison expression. + fn is_comparison_operator(&self) -> bool { + false + } + + /// Return true if the operator is a logic operator. + /// + /// For example, `Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))` would + /// be a logical expression. + fn is_logic_operator(&self) -> bool { + false + } + + /// Return the operator where swapping lhs and rhs wouldn't change the result. + /// + /// For example `Binary(50, >=, a)` could also be represented as `Binary(a, <=, 50)`. + fn swap(&self) -> Option { + None + } + + /// Get the operator precedence + /// use as a reference + fn precedence(&self) -> u8 { + 0 + } +} + +/// Wraps a [`CustomOperator`] and implements traits required by [`Operator`]. +/// +/// This uses [`CustomOperator::name`] for equality, partial equality, ordering, and hashing; and therefore assumes +/// it is unique for each custom operator. +/// +/// See details on why dyn traits can't implement +/// `PartialEq` and friends. +#[derive(Debug, Clone)] +pub struct WrapCustomOperator(pub Arc); + +impl Eq for WrapCustomOperator {} + +impl PartialEq for WrapCustomOperator { + fn eq(&self, rhs: &Self) -> bool { + self.0.name() == rhs.0.name() + } +} + +impl PartialOrd for WrapCustomOperator { + fn partial_cmp(&self, rhs: &Self) -> Option { + self.0.name().partial_cmp(rhs.0.name()) + } +} + +impl Hash for WrapCustomOperator { + fn hash(&self, state: &mut H) { + self.0.name().hash(state) } } diff --git a/datafusion/expr/src/parse_custom_operator.rs b/datafusion/expr/src/parse_custom_operator.rs new file mode 100644 index 000000000000..e83e89d5c663 --- /dev/null +++ b/datafusion/expr/src/parse_custom_operator.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; + +use datafusion_common::Result; +use sqlparser::ast::BinaryOperator; + +use crate::Operator; + +pub trait ParseCustomOperator: Debug + Send + Sync { + /// Return a human-readable name for this parser + fn name(&self) -> &str; + + /// potentially parse a custom operator. + /// + /// Return `None` if the operator is not recognized + fn op_from_ast(&self, op: &BinaryOperator) -> Result>; + + fn op_from_name(&self, raw_op: &str) -> Result>; +} diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 70d0a21a870e..9cd04269d6c4 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -18,7 +18,9 @@ //! FunctionRegistry trait use crate::expr_rewriter::FunctionRewrite; -use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; +use crate::{ + AggregateUDF, ParseCustomOperator, ScalarUDF, UserDefinedLogicalNode, WindowUDF, +}; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use std::collections::HashMap; use std::{collections::HashSet, sync::Arc}; @@ -108,6 +110,21 @@ pub trait FunctionRegistry { ) -> Result<()> { not_impl_err!("Registering FunctionRewrite") } + + /// Registers a new [`ParseCustomOperator`] with the registry. + /// + /// `ParseCustomOperator` is used to parse custom operators from SQL, + /// e.g. `->>` or `?`. + fn register_parse_custom_operator( + &mut self, + _parse_custom_operator: Arc, + ) -> Result<()> { + not_impl_err!("Registering ParseCustomOperator") + } + + fn parse_custom_operators(&self) -> Vec> { + vec![] + } } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 5645a2a4dede..e517f5f7dc91 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -20,7 +20,7 @@ use std::collections::HashSet; use std::sync::Arc; -use crate::Operator; +use crate::{Operator, WrapCustomOperator}; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; @@ -189,6 +189,10 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result ) } } + Custom(WrapCustomOperator(op)) => { + let (lhs, rhs, ret) = op.binary_signature(lhs, rhs)?; + Ok(Signature { lhs, rhs, ret }) + } } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d19279c20d10..5ba306d36999 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -638,6 +638,9 @@ impl BinaryExpr { AtArrow | ArrowAt => { unreachable!("ArrowAt and AtArrow should be rewritten to function") } + Custom(_) => { + internal_err!("Custom operator should be rewritten to functions") + } } } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 21331a94c18c..d223789df033 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -31,8 +31,8 @@ use datafusion_expr::{ AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, - JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, - WindowFrameUnits, + JoinConstraint, JoinType, Like, Operator, ParseCustomOperator, TryCast, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; use datafusion_proto_common::{from_proto::FromOptionalField, FromProtoError as Error}; @@ -252,7 +252,10 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { - let op = from_proto_binary_op(&binary_expr.op)?; + let op = from_proto_binary_op( + &binary_expr.op, + ®istry.parse_custom_operators(), + )?; let operands = parse_exprs(&binary_expr.operands, registry, codec)?; if operands.len() < 2 { @@ -676,7 +679,16 @@ fn parse_escape_char(s: &str) -> Result> { } } -pub fn from_proto_binary_op(op: &str) -> Result { +pub fn from_proto_binary_op( + op: &str, + parse_custom_operator: &[Arc], +) -> Result { + for parse_custom_op in parse_custom_operator { + if let Some(op) = parse_custom_op.op_from_name(op)? { + return Ok(op); + } + } + match op { "And" => Ok(Operator::And), "Or" => Ok(Operator::Or), diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index b636c77641c7..545d2f81f408 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -241,7 +241,10 @@ pub fn parse_physical_expr( input_schema, codec, )?, - logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, + logical_plan::from_proto::from_proto_binary_op( + &binary_expr.op, + ®istry.parse_custom_operators(), + )?, parse_required_physical_expr( binary_expr.r.as_deref(), registry, diff --git a/datafusion/sql/src/expr/binary_op.rs b/datafusion/sql/src/expr/binary_op.rs index fcb57e8a82e4..a9ca1583f448 100644 --- a/datafusion/sql/src/expr/binary_op.rs +++ b/datafusion/sql/src/expr/binary_op.rs @@ -22,6 +22,12 @@ use sqlparser::ast::BinaryOperator; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_sql_binary_op(&self, op: BinaryOperator) -> Result { + for parse_custom_op in &self.options.parse_custom_operator { + if let Some(op) = parse_custom_op.op_from_ast(&op)? { + return Ok(op); + } + } + match op { BinaryOperator::Gt => Ok(Operator::Gt), BinaryOperator::GtEq => Ok(Operator::GtEq), diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 30f95170a34f..b885afb18b50 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -24,7 +24,7 @@ use arrow_schema::*; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::WindowUDF; +use datafusion_expr::{ParseCustomOperator, WindowUDF}; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -97,6 +97,7 @@ pub trait ContextProvider { pub struct ParserOptions { pub parse_float_as_decimal: bool, pub enable_ident_normalization: bool, + pub parse_custom_operator: Vec>, } impl Default for ParserOptions { @@ -104,6 +105,7 @@ impl Default for ParserOptions { Self { parse_float_as_decimal: false, enable_ident_normalization: true, + parse_custom_operator: vec![], } } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index ad898de5987a..2214fcb44625 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -41,6 +41,7 @@ use datafusion_common::{ use datafusion_expr::{ expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, + WrapCustomOperator, }; use super::Unparser; @@ -649,6 +650,7 @@ impl Unparser<'_> { Operator::BitwiseShiftRight => Ok(ast::BinaryOperator::PGBitwiseShiftRight), Operator::BitwiseShiftLeft => Ok(ast::BinaryOperator::PGBitwiseShiftLeft), Operator::StringConcat => Ok(ast::BinaryOperator::StringConcat), + Operator::Custom(WrapCustomOperator(op)) => op.op_to_sql(), Operator::AtArrow => not_impl_err!("unsupported operation: {op:?}"), Operator::ArrowAt => not_impl_err!("unsupported operation: {op:?}"), } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index f196d71d41de..977fd738e322 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -84,6 +84,7 @@ fn parse_decimals() { ParserOptions { parse_float_as_decimal: true, enable_ident_normalization: false, + ..Default::default() }, ); } @@ -137,6 +138,7 @@ fn parse_ident_normalization() { ParserOptions { parse_float_as_decimal: false, enable_ident_normalization, + ..Default::default() }, ); if plan.is_ok() { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 302f38606bfb..fb01bbb9136e 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use arrow_buffer::ToByteSlice; use datafusion::arrow::datatypes::IntervalUnit; use datafusion::logical_expr::{ - CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, + CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, WrapCustomOperator, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -723,6 +723,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::BitwiseXor => "bitwise_xor", Operator::BitwiseShiftRight => "bitwise_shift_right", Operator::BitwiseShiftLeft => "bitwise_shift_left", + Operator::Custom(WrapCustomOperator(op)) => op.name(), } }