diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 1c2bb065b552..b2023574848d 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -1694,6 +1694,49 @@ mod tests { Ok(()) } + #[test] + fn plus_op_dict_scalar_decimal() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Decimal128(10, 0)), + ), + true, + )]); + + let value = 123; + let decimal_array = Arc::new(create_decimal_array( + &[Some(value), None, Some(value - 1), Some(value + 1)], + 10, + 0, + )) as ArrayRef; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let a = DictionaryArray::try_new(&keys, &decimal_array)?; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let decimal_array = create_decimal_array( + &[Some(value + 1), None, Some(value), Some(value + 2)], + 10, + 0, + ); + let expected = DictionaryArray::try_new(&keys, &decimal_array)?; + + apply_arithmetic_scalar( + Arc::new(schema), + vec![Arc::new(a)], + Operator::Plus, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Decimal128(Some(1), 10, 0)), + ), + Arc::new(expected), + )?; + + Ok(()) + } + #[test] fn minus_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1776,6 +1819,49 @@ mod tests { Ok(()) } + #[test] + fn minus_op_dict_scalar_decimal() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Decimal128(10, 0)), + ), + true, + )]); + + let value = 123; + let decimal_array = Arc::new(create_decimal_array( + &[Some(value), None, Some(value - 1), Some(value + 1)], + 10, + 0, + )) as ArrayRef; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let a = DictionaryArray::try_new(&keys, &decimal_array)?; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let decimal_array = create_decimal_array( + &[Some(value - 1), None, Some(value - 2), Some(value)], + 10, + 0, + ); + let expected = DictionaryArray::try_new(&keys, &decimal_array)?; + + apply_arithmetic_scalar( + Arc::new(schema), + vec![Arc::new(a)], + Operator::Minus, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Decimal128(Some(1), 10, 0)), + ), + Arc::new(expected), + )?; + + Ok(()) + } + #[test] fn multiply_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1850,6 +1936,46 @@ mod tests { Ok(()) } + #[test] + fn multiply_op_dict_scalar_decimal() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Decimal128(10, 0)), + ), + true, + )]); + + let value = 123; + let decimal_array = Arc::new(create_decimal_array( + &[Some(value), None, Some(value - 1), Some(value + 1)], + 10, + 0, + )) as ArrayRef; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let a = DictionaryArray::try_new(&keys, &decimal_array)?; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let decimal_array = + create_decimal_array(&[Some(246), None, Some(244), Some(248)], 10, 0); + let expected = DictionaryArray::try_new(&keys, &decimal_array)?; + + apply_arithmetic_scalar( + Arc::new(schema), + vec![Arc::new(a)], + Operator::Multiply, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), + ), + Arc::new(expected), + )?; + + Ok(()) + } + #[test] fn divide_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1924,6 +2050,46 @@ mod tests { Ok(()) } + #[test] + fn divide_op_dict_scalar_decimal() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Decimal128(10, 0)), + ), + true, + )]); + + let value = 123; + let decimal_array = Arc::new(create_decimal_array( + &[Some(value), None, Some(value - 1), Some(value + 1)], + 10, + 0, + )) as ArrayRef; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let a = DictionaryArray::try_new(&keys, &decimal_array)?; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let decimal_array = + create_decimal_array(&[Some(61), None, Some(61), Some(62)], 10, 0); + let expected = DictionaryArray::try_new(&keys, &decimal_array)?; + + apply_arithmetic_scalar( + Arc::new(schema), + vec![Arc::new(a)], + Operator::Divide, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), + ), + Arc::new(expected), + )?; + + Ok(()) + } + #[test] fn modulus_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs index 40e0d2b0ed90..b75040f41974 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs @@ -189,28 +189,20 @@ pub(crate) fn add_decimal( } pub(crate) fn add_decimal_dyn_scalar(left: &dyn Array, right: i128) -> Result { - let left_decimal = left.as_any().downcast_ref::().unwrap(); + let (precision, scale) = get_precision_scale(left)?; let array = add_scalar_dyn::(left, right)?; - let decimal_array = as_decimal128_array(&array)?; - let decimal_array = decimal_array - .clone() - .with_precision_and_scale(left_decimal.precision(), left_decimal.scale())?; - Ok(Arc::new(decimal_array)) + decimal_array_with_precision_scale(array, precision, scale) } pub(crate) fn subtract_decimal_dyn_scalar( left: &dyn Array, right: i128, ) -> Result { - let left_decimal = left.as_any().downcast_ref::().unwrap(); + let (precision, scale) = get_precision_scale(left)?; let array = subtract_scalar_dyn::(left, right)?; - let decimal_array = as_decimal128_array(&array)?; - let decimal_array = decimal_array - .clone() - .with_precision_and_scale(left_decimal.precision(), left_decimal.scale())?; - Ok(Arc::new(decimal_array)) + decimal_array_with_precision_scale(array, precision, scale) } fn get_precision_scale(left: &dyn Array) -> Result<(u8, i8)> {