Skip to content

Commit

Permalink
Support arithmetic dyn scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Feb 1, 2023
1 parent 7e94826 commit 699d7c2
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 59 deletions.
15 changes: 15 additions & 0 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,15 @@ fn mathematics_numerical_coercion(
(Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
Some(dec_type.clone())
}
(Dictionary(key_type, value_type), _) => {
let value_type =
mathematics_numerical_coercion(mathematics_op, value_type, rhs_type);
value_type
.map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)))
}
(_, Dictionary(_, value_type)) => {
mathematics_numerical_coercion(mathematics_op, lhs_type, value_type)
}
(Decimal128(_, _), Float32 | Float64) => Some(Float64),
(Float32 | Float64, Decimal128(_, _)) => Some(Float64),
(Decimal128(_, _), _) => {
Expand Down Expand Up @@ -439,6 +448,12 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) ->
match (lhs_type, rhs_type) {
(_, DataType::Null) => is_numeric(lhs_type),
(DataType::Null, _) => is_numeric(rhs_type),
(DataType::Dictionary(_, value_type), _) => {
is_numeric(value_type) && is_numeric(rhs_type)
}
(_, DataType::Dictionary(_, value_type)) => {
is_numeric(lhs_type) && is_numeric(value_type)
}
_ => is_numeric(lhs_type) && is_numeric(rhs_type),
}
}
Expand Down
111 changes: 103 additions & 8 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ use std::{any::Any, sync::Arc};

use arrow::array::*;
use arrow::compute::kernels::arithmetic::{
add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar, multiply,
multiply_scalar, subtract, subtract_scalar,
add, add_scalar_dyn as add_dyn_scalar, divide_opt,
divide_scalar_dyn as divide_dyn_scalar, modulus, modulus_scalar, multiply,
multiply_scalar_dyn as multiply_dyn_scalar, subtract,
subtract_scalar_dyn as subtract_dyn_scalar,
};
use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
use arrow::compute::kernels::comparison::regexp_is_match_utf8;
Expand All @@ -49,6 +51,7 @@ use arrow::compute::kernels::comparison::{
use arrow::compute::kernels::comparison::{
eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
};
use arrow::datatypes::*;

use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
use arrow::compute::kernels::concat_elements::concat_elements_utf8;
Expand All @@ -58,12 +61,12 @@ use kernels::{
bitwise_xor, bitwise_xor_scalar,
};
use kernels_arrow::{
add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal,
add_decimal, add_decimal_dyn_scalar, divide_decimal_dyn_scalar, divide_opt_decimal,
is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal,
is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_null,
is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar, multiply_decimal,
multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar,
multiply_decimal_dyn_scalar, subtract_decimal, subtract_decimal_dyn_scalar,
};

use arrow::datatypes::{DataType, Schema, TimeUnit};
Expand Down Expand Up @@ -315,6 +318,46 @@ macro_rules! compute_op_dyn_scalar {
}};
}

/// Invoke a dyn compute kernel on a data array and a scalar value
/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value
/// OP_TYPE is the return type of scalar function
/// SCALAR_TYPE is the type of the scalar value
macro_rules! compute_primitive_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr, $SCALAR_TYPE:ident) => {{
// generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter
// (which could have a value of lt_dyn) and the suffix _scalar
if let Some(value) = $RIGHT {
Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]::<$SCALAR_TYPE>}(
$LEFT,
value,
)?))
} else {
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
}
}};
}

/// Invoke a dyn decimal compute kernel on a data array and a scalar value
/// LEFT is Decimal or Dictionary array of decimal values, RIGHT is scalar value
/// OP_TYPE is the return type of scalar function
/// SCALAR_TYPE is the type of the scalar value
macro_rules! compute_primitive_decimal_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
// generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter
// (which could have a value of lt_dyn) and the suffix _scalar
if let Some(value) = $RIGHT {
Ok(paste::expr! {[<$OP _decimal_dyn_scalar>]}(
$LEFT,
value,
)?)
} else {
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
}
}};
}

/// Invoke a compute kernel on array(s)
macro_rules! compute_op {
// invoke binary operator
Expand Down Expand Up @@ -376,6 +419,36 @@ macro_rules! binary_primitive_array_op {
}};
}

/// Invoke a compute dyn kernel on an array and a scalar
/// The binary_primitive_array_op_dyn_scalar macro only evaluates for primitive
/// types like integers and floats.
macro_rules! binary_primitive_array_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
// unwrap underlying (non dictionary) value
let right = unwrap_dict_value($RIGHT);

let result: Result<Arc<dyn Array>> = match right {
ScalarValue::Decimal128(v, _, _) => compute_primitive_decimal_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
ScalarValue::Int8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int8Type),
ScalarValue::Int16(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int16Type),
ScalarValue::Int32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int32Type),
ScalarValue::Int64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Int64Type),
ScalarValue::UInt8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt8Type),
ScalarValue::UInt16(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt16Type),
ScalarValue::UInt32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt32Type),
ScalarValue::UInt64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, UInt64Type),
ScalarValue::Float32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Float32Type),
ScalarValue::Float64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE, Float64Type),
other => Err(DataFusionError::Internal(format!(
"Data type {:?} not supported for scalar operation '{}' on dyn array",
other, stringify!($OP)))
)
};

Some(result)
}}
}

/// Invoke a compute kernel on an array and a scalar
/// The binary_primitive_array_op_scalar macro only evaluates for primitive
/// types like integers and floats.
Expand Down Expand Up @@ -904,6 +977,7 @@ impl BinaryExpr {
scalar: &ScalarValue,
) -> Result<Option<Result<ArrayRef>>> {
let bool_type = &DataType::Boolean;
let left_type = array.data_type();
let scalar_result = match &self.op {
Operator::Lt => {
binary_array_op_dyn_scalar!(array, scalar.clone(), lt, bool_type)
Expand All @@ -924,18 +998,39 @@ impl BinaryExpr {
binary_array_op_dyn_scalar!(array, scalar.clone(), neq, bool_type)
}
Operator::Plus => {
binary_primitive_array_op_scalar!(array, scalar.clone(), add)
binary_primitive_array_op_dyn_scalar!(
array,
scalar.clone(),
add,
left_type
)
}
Operator::Minus => {
binary_primitive_array_op_scalar!(array, scalar.clone(), subtract)
binary_primitive_array_op_dyn_scalar!(
array,
scalar.clone(),
subtract,
left_type
)
}
Operator::Multiply => {
binary_primitive_array_op_scalar!(array, scalar.clone(), multiply)
binary_primitive_array_op_dyn_scalar!(
array,
scalar.clone(),
multiply,
left_type
)
}
Operator::Divide => {
binary_primitive_array_op_scalar!(array, scalar.clone(), divide)
binary_primitive_array_op_dyn_scalar!(
array,
scalar.clone(),
divide,
left_type
)
}
Operator::Modulo => {
// todo: change to binary_primitive_array_op_dyn_scalar! once modulo is implemented
binary_primitive_array_op_scalar!(array, scalar.clone(), modulus)
}
Operator::RegexMatch => binary_string_array_flag_op_scalar!(
Expand Down
Loading

0 comments on commit 699d7c2

Please sign in to comment.