From 699d7c28bb9cb817eb4d9e54f9dd2fdf1c7e70d2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 1 Feb 2023 11:07:22 -0800 Subject: [PATCH] Support arithmetic dyn scalar --- datafusion/expr/src/type_coercion/binary.rs | 15 ++ .../physical-expr/src/expressions/binary.rs | 111 ++++++++++- .../src/expressions/binary/kernels_arrow.rs | 172 ++++++++++++------ 3 files changed, 239 insertions(+), 59 deletions(-) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index d2923c8dbfae..cd6356fa2944 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -341,6 +341,15 @@ fn mathematics_numerical_coercion( (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { Some(dec_type.clone()) } + (Dictionary(key_type, value_type), _) => { + let value_type = + mathematics_numerical_coercion(mathematics_op, value_type, rhs_type); + value_type + .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))) + } + (_, Dictionary(_, value_type)) => { + mathematics_numerical_coercion(mathematics_op, lhs_type, value_type) + } (Decimal128(_, _), Float32 | Float64) => Some(Float64), (Float32 | Float64, Decimal128(_, _)) => Some(Float64), (Decimal128(_, _), _) => { @@ -439,6 +448,12 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> match (lhs_type, rhs_type) { (_, DataType::Null) => is_numeric(lhs_type), (DataType::Null, _) => is_numeric(rhs_type), + (DataType::Dictionary(_, value_type), _) => { + is_numeric(value_type) && is_numeric(rhs_type) + } + (_, DataType::Dictionary(_, value_type)) => { + is_numeric(lhs_type) && is_numeric(value_type) + } _ => is_numeric(lhs_type) && is_numeric(rhs_type), } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d2346d278dc4..e54806babb99 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -24,8 +24,10 @@ use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::arithmetic::{ - add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar, multiply, - multiply_scalar, subtract, subtract_scalar, + add, add_scalar_dyn as add_dyn_scalar, divide_opt, + divide_scalar_dyn as divide_dyn_scalar, modulus, modulus_scalar, multiply, + multiply_scalar_dyn as multiply_dyn_scalar, subtract, + subtract_scalar_dyn as subtract_dyn_scalar, }; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::comparison::regexp_is_match_utf8; @@ -49,6 +51,7 @@ use arrow::compute::kernels::comparison::{ use arrow::compute::kernels::comparison::{ eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, }; +use arrow::datatypes::*; use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; @@ -58,12 +61,12 @@ use kernels::{ bitwise_xor, bitwise_xor_scalar, }; use kernels_arrow::{ - add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal, + add_decimal, add_decimal_dyn_scalar, divide_decimal_dyn_scalar, divide_opt_decimal, is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from, is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_null, is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar, multiply_decimal, - multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar, + multiply_decimal_dyn_scalar, subtract_decimal, subtract_decimal_dyn_scalar, }; use arrow::datatypes::{DataType, Schema, TimeUnit}; @@ -315,6 +318,46 @@ macro_rules! compute_op_dyn_scalar { }}; } +/// Invoke a dyn compute kernel on a data array and a scalar value +/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value +/// OP_TYPE is the return type of scalar function +/// SCALAR_TYPE is the type of the scalar value +macro_rules! compute_primitive_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr, $SCALAR_TYPE:ident) => {{ + // generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter + // (which could have a value of lt_dyn) and the suffix _scalar + if let Some(value) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]::<$SCALAR_TYPE>}( + $LEFT, + value, + )?)) + } else { + // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE + Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) + } + }}; +} + +/// Invoke a dyn decimal compute kernel on a data array and a scalar value +/// LEFT is Decimal or Dictionary array of decimal values, RIGHT is scalar value +/// OP_TYPE is the return type of scalar function +/// SCALAR_TYPE is the type of the scalar value +macro_rules! compute_primitive_decimal_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ + // generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter + // (which could have a value of lt_dyn) and the suffix _scalar + if let Some(value) = $RIGHT { + Ok(paste::expr! {[<$OP _decimal_dyn_scalar>]}( + $LEFT, + value, + )?) + } else { + // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE + Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) + } + }}; +} + /// Invoke a compute kernel on array(s) macro_rules! compute_op { // invoke binary operator @@ -376,6 +419,36 @@ macro_rules! binary_primitive_array_op { }}; } +/// Invoke a compute dyn kernel on an array and a scalar +/// The binary_primitive_array_op_dyn_scalar macro only evaluates for primitive +/// types like integers and floats. +macro_rules! binary_primitive_array_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ + // unwrap underlying (non dictionary) value + let right = unwrap_dict_value($RIGHT); + + let result: Result> = match right { + ScalarValue::Decimal128(v, _, _) => compute_primitive_decimal_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), + ScalarValue::Int8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int8Type), + ScalarValue::Int16(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int16Type), + ScalarValue::Int32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int32Type), + ScalarValue::Int64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int64Type), + ScalarValue::UInt8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt8Type), + ScalarValue::UInt16(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt16Type), + ScalarValue::UInt32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt32Type), + ScalarValue::UInt64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt64Type), + ScalarValue::Float32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Float32Type), + ScalarValue::Float64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Float64Type), + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for scalar operation '{}' on dyn array", + other, stringify!($OP))) + ) + }; + + Some(result) + }} +} + /// Invoke a compute kernel on an array and a scalar /// The binary_primitive_array_op_scalar macro only evaluates for primitive /// types like integers and floats. @@ -904,6 +977,7 @@ impl BinaryExpr { scalar: &ScalarValue, ) -> Result>> { let bool_type = &DataType::Boolean; + let left_type = array.data_type(); let scalar_result = match &self.op { Operator::Lt => { binary_array_op_dyn_scalar!(array, scalar.clone(), lt, bool_type) @@ -924,18 +998,39 @@ impl BinaryExpr { binary_array_op_dyn_scalar!(array, scalar.clone(), neq, bool_type) } Operator::Plus => { - binary_primitive_array_op_scalar!(array, scalar.clone(), add) + binary_primitive_array_op_dyn_scalar!( + array, + scalar.clone(), + add, + left_type + ) } Operator::Minus => { - binary_primitive_array_op_scalar!(array, scalar.clone(), subtract) + binary_primitive_array_op_dyn_scalar!( + array, + scalar.clone(), + subtract, + left_type + ) } Operator::Multiply => { - binary_primitive_array_op_scalar!(array, scalar.clone(), multiply) + binary_primitive_array_op_dyn_scalar!( + array, + scalar.clone(), + multiply, + left_type + ) } Operator::Divide => { - binary_primitive_array_op_scalar!(array, scalar.clone(), divide) + binary_primitive_array_op_dyn_scalar!( + array, + scalar.clone(), + divide, + left_type + ) } Operator::Modulo => { + // todo: change to binary_primitive_array_op_dyn_scalar! once modulo is implemented binary_primitive_array_op_scalar!(array, scalar.clone(), modulus) } Operator::RegexMatch => binary_string_array_flag_op_scalar!( diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs index 2135982b67f8..40e0d2b0ed90 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs @@ -19,11 +19,16 @@ //! destined for arrow-rs but are in datafusion until they are ported. use arrow::compute::{ - add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar, multiply, - multiply_scalar, subtract, subtract_scalar, + add, add_scalar_dyn, divide_opt, divide_scalar, divide_scalar_dyn, modulus, + modulus_scalar, multiply, multiply_scalar, multiply_scalar_dyn, subtract, + subtract_scalar_dyn, }; -use arrow::{array::*, datatypes::ArrowNumericType}; -use datafusion_common::Result; +use arrow::datatypes::Decimal128Type; +use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array}; +use arrow_schema::DataType; +use datafusion_common::cast::as_decimal128_array; +use datafusion_common::{DataFusionError, Result}; +use std::sync::Arc; // Simple (low performance) kernels until optimized kernels are added to arrow // See https://github.com/apache/arrow-rs/issues/960 @@ -183,50 +188,123 @@ pub(crate) fn add_decimal( Ok(array) } -pub(crate) fn add_decimal_scalar( - left: &Decimal128Array, +pub(crate) fn add_decimal_dyn_scalar(left: &dyn Array, right: i128) -> Result { + let left_decimal = left.as_any().downcast_ref::().unwrap(); + + let array = add_scalar_dyn::(left, right)?; + let decimal_array = as_decimal128_array(&array)?; + let decimal_array = decimal_array + .clone() + .with_precision_and_scale(left_decimal.precision(), left_decimal.scale())?; + Ok(Arc::new(decimal_array)) +} + +pub(crate) fn subtract_decimal_dyn_scalar( + left: &dyn Array, right: i128, -) -> Result { - let array = add_scalar(left, right)? - .with_precision_and_scale(left.precision(), left.scale())?; - Ok(array) +) -> Result { + let left_decimal = left.as_any().downcast_ref::().unwrap(); + + let array = subtract_scalar_dyn::(left, right)?; + let decimal_array = as_decimal128_array(&array)?; + let decimal_array = decimal_array + .clone() + .with_precision_and_scale(left_decimal.precision(), left_decimal.scale())?; + Ok(Arc::new(decimal_array)) } -pub(crate) fn subtract_decimal( - left: &Decimal128Array, - right: &Decimal128Array, -) -> Result { - let array = subtract(left, right)? - .with_precision_and_scale(left.precision(), left.scale())?; - Ok(array) +fn get_precision_scale(left: &dyn Array) -> Result<(u8, i8)> { + match left.data_type() { + DataType::Decimal128(precision, scale) => Ok((*precision, *scale)), + DataType::Dictionary(_, value_type) => match value_type.as_ref() { + DataType::Decimal128(precision, scale) => Ok((*precision, *scale)), + _ => Err(DataFusionError::Internal( + "Unexpected data type".to_string(), + )), + }, + _ => Err(DataFusionError::Internal( + "Unexpected data type".to_string(), + )), + } } -pub(crate) fn subtract_decimal_scalar( - left: &Decimal128Array, +fn decimal_array_with_precision_scale( + array: ArrayRef, + precision: u8, + scale: i8, +) -> Result { + let array = array.as_ref(); + let decimal_array = match array.data_type() { + DataType::Decimal128(_, _) => { + let array = as_decimal128_array(array)?; + Arc::new(array.clone().with_precision_and_scale(precision, scale)?) + as ArrayRef + } + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + array => match array.values().data_type() { + DataType::Decimal128(_, _) => { + let decimal_dict_array = array.downcast_dict::().unwrap(); + let decimal_array = decimal_dict_array.values().clone(); + let decimal_array = decimal_array.with_precision_and_scale(precision, scale)?; + Arc::new(array.with_values(&decimal_array)) as ArrayRef + } + t => return Err(DataFusionError::Internal(format!("Unexpected dictionary value type {t}"))), + }, + t => return Err(DataFusionError::Internal(format!("Unexpected datatype {t}"))), + ) + } + _ => { + return Err(DataFusionError::Internal( + "Unexpected data type".to_string(), + )) + } + }; + Ok(decimal_array) +} + +pub(crate) fn multiply_decimal_dyn_scalar( + left: &dyn Array, right: i128, -) -> Result { - let array = subtract_scalar(left, right)? - .with_precision_and_scale(left.precision(), left.scale())?; - Ok(array) +) -> Result { + let (precision, scale) = get_precision_scale(left)?; + + let array = multiply_scalar_dyn::(left, right)?; + + let divide = 10_i128.pow(scale as u32); + let array = divide_scalar_dyn::(&array, divide)?; + + decimal_array_with_precision_scale(array, precision, scale) } -pub(crate) fn multiply_decimal( +pub(crate) fn divide_decimal_dyn_scalar( + left: &dyn Array, + right: i128, +) -> Result { + let (precision, scale) = get_precision_scale(left)?; + + let mul = 10_i128.pow(scale as u32); + let array = multiply_scalar_dyn::(left, mul)?; + + let array = divide_scalar_dyn::(&array, right)?; + decimal_array_with_precision_scale(array, precision, scale) +} + +pub(crate) fn subtract_decimal( left: &Decimal128Array, right: &Decimal128Array, ) -> Result { - let divide = 10_i128.pow(left.scale() as u32); - let array = multiply(left, right)?; - let array = divide_scalar(&array, divide)? + let array = subtract(left, right)? .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) } -pub(crate) fn multiply_decimal_scalar( +pub(crate) fn multiply_decimal( left: &Decimal128Array, - right: i128, + right: &Decimal128Array, ) -> Result { - let array = multiply_scalar(left, right)?; let divide = 10_i128.pow(left.scale() as u32); + let array = multiply(left, right)?; let array = divide_scalar(&array, divide)? .with_precision_and_scale(left.precision(), left.scale())?; Ok(array) @@ -243,18 +321,6 @@ pub(crate) fn divide_opt_decimal( Ok(array) } -pub(crate) fn divide_decimal_scalar( - left: &Decimal128Array, - right: i128, -) -> Result { - let mul = 10_i128.pow(left.scale() as u32); - let array = multiply_scalar(left, mul)?; - // `0` of right will be checked in `divide_scalar` - let array = divide_scalar(&array, right)? - .with_precision_and_scale(left.precision(), left.scale())?; - Ok(array) -} - pub(crate) fn modulus_decimal( left: &Decimal128Array, right: &Decimal128Array, @@ -371,25 +437,28 @@ mod tests { let expect = create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 3); assert_eq!(expect, result); - let result = add_decimal_scalar(&left_decimal_array, 10)?; + let result = add_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(133), None, Some(132), Some(134)], 25, 3); - assert_eq!(expect, result); + assert_eq!(&expect, result); // subtract let result = subtract_decimal(&left_decimal_array, &right_decimal_array)?; let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3); assert_eq!(expect, result); - let result = subtract_decimal_scalar(&left_decimal_array, 10)?; + let result = subtract_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(113), None, Some(112), Some(114)], 25, 3); - assert_eq!(expect, result); + assert_eq!(&expect, result); // multiply let result = multiply_decimal(&left_decimal_array, &right_decimal_array)?; let expect = create_decimal_array(&[Some(15), None, Some(15), Some(15)], 25, 3); assert_eq!(expect, result); - let result = multiply_decimal_scalar(&left_decimal_array, 10)?; + let result = multiply_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(1), None, Some(1), Some(1)], 25, 3); - assert_eq!(expect, result); + assert_eq!(&expect, result); // divide let left_decimal_array = create_decimal_array( &[ @@ -414,7 +483,8 @@ mod tests { 3, ); assert_eq!(expect, result); - let result = divide_decimal_scalar(&left_decimal_array, 10)?; + let result = divide_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = as_decimal128_array(&result)?; let expect = create_decimal_array( &[ Some(123456700), @@ -426,7 +496,7 @@ mod tests { 25, 3, ); - assert_eq!(expect, result); + assert_eq!(&expect, result); let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?; let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16), None], 25, 3); @@ -444,7 +514,7 @@ mod tests { let left_decimal_array = create_decimal_array(&[Some(101)], 10, 1); let right_decimal_array = create_decimal_array(&[Some(0)], 1, 1); - let err = divide_decimal_scalar(&left_decimal_array, 0).unwrap_err(); + let err = divide_decimal_dyn_scalar(&left_decimal_array, 0).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string()); let err = modulus_decimal(&left_decimal_array, &right_decimal_array).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string());