diff --git a/src/scalar_funcs.rs b/src/scalar_funcs.rs index f4270220d708..5cc3f3dd7e9a 100644 --- a/src/scalar_funcs.rs +++ b/src/scalar_funcs.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use arrow::compute::kernels::numeric::{add, sub}; +use arrow::datatypes::IntervalDayTime; use arrow::{ array::{ ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array, @@ -22,9 +24,11 @@ use arrow::{ }, datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, }; -use arrow_array::builder::GenericStringBuilder; -use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array}; -use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION}; +use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder}; +use arrow_array::types::{Int16Type, Int32Type, Int8Type}; +use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array}; +use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION}; +use datafusion::physical_expr_common::datum; use datafusion::{functions::math::round::round, physical_plan::ColumnarValue}; use datafusion_common::{ cast::as_generic_string_array, exec_err, internal_err, DataFusionError, @@ -547,3 +551,76 @@ pub fn spark_isnan(args: &[ColumnarValue]) -> Result {{ + let interval = IntervalDayTime::new(*$days as i32, 0); + let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval))); + datum::apply($start, &interval_cv, $op) + }}; +} +macro_rules! array_date_arithmetic { + ($days:expr, $interval_builder:expr, $intType:ty) => {{ + for day in $days.as_primitive::<$intType>().into_iter() { + if let Some(non_null_day) = day { + $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0)); + } else { + $interval_builder.append_null(); + } + } + }}; +} + +/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second +/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the +/// second argument and use DataFusion's interface to apply Arrow's operators. +fn spark_date_arithmetic( + args: &[ColumnarValue], + op: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + let start = &args[0]; + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Array(days) => { + let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len()); + match days.data_type() { + DataType::Int8 => { + array_date_arithmetic!(days, interval_builder, Int8Type) + } + DataType::Int16 => { + array_date_arithmetic!(days, interval_builder, Int16Type) + } + DataType::Int32 => { + array_date_arithmetic!(days, interval_builder, Int32Type) + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))) + } + } + let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish())); + datum::apply(start, &interval_cv, op) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))), + } +} +pub fn spark_date_add(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, add) +} + +pub fn spark_date_sub(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, sub) +}