feat: date_add and date_sub functions (apache#910)

* date_add test case.

* Add DateAdd to proto and QueryPlanSerde. Next up is the native side.

* Add DateAdd in planner.rs that generates a Literal for right child. Need to confirm if any other type of expression can occur here.

* Minor refactor.

* Change test predicate to actually select some rows.

* Switch to scalar UDF implementation for date_add.

* Docs and minor refactor.

* Add a new test to explicitly cover array scenario.

* cargo clippy fixes.

* Fix Scala 2.13.

* New approved plans for q72 due to date_add.

* Address first round of feedback.

* Add date_sub and tests.

* Fix error message to be more general.

* Update error message for Spark 4.0+

* Support Int8 and Int16 for days.
mbutrovich authored Sep 16, 2024
1 parent c7ed2eb commit c7ec300
Showing 1 changed file with 80 additions and 3 deletions.
src/scalar_funcs.rs: 80 additions & 3 deletions
@@ -15,16 +15,20 @@
// specific language governing permissions and limitations
// under the License.

use arrow::compute::kernels::numeric::{add, sub};
use arrow::datatypes::IntervalDayTime;
use arrow::{
    array::{
        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
    },
    datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::builder::GenericStringBuilder;
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder};
use arrow_array::types::{Int16Type, Int32Type, Int8Type};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
use datafusion::physical_expr_common::datum;
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{
    cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
@@ -547,3 +551,76 @@ pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError
        },
    }
}

// Builds a one-element day interval from a scalar day count and applies `op` through
// DataFusion's datum helper.
macro_rules! scalar_date_arithmetic {
    ($start:expr, $days:expr, $op:expr) => {{
        let interval = IntervalDayTime::new(*$days as i32, 0);
        let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
        datum::apply($start, &interval_cv, $op)
    }};
}

// Converts an integer day-count array into `IntervalDayTime` values, preserving nulls.
macro_rules! array_date_arithmetic {
    ($days:expr, $interval_builder:expr, $intType:ty) => {{
        for day in $days.as_primitive::<$intType>().into_iter() {
            if let Some(non_null_day) = day {
                $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
            } else {
                $interval_builder.append_null();
            }
        }
    }};
}
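
Both macros lean on Arrow's support for adding an IntervalDayTime to Date32 data. A rough standalone sketch of that underlying kernel call (illustrative only, not part of this commit; the function name and sample values are invented):

use arrow::compute::kernels::numeric::add;
use arrow::datatypes::IntervalDayTime;
use arrow_array::builder::IntervalDayTimeBuilder;
use arrow_array::{Array, Date32Array, Scalar};
use arrow_schema::ArrowError;

fn add_seven_days() -> Result<(), ArrowError> {
    // Date32 values are days since the Unix epoch.
    let dates = Date32Array::from(vec![19000, 19001]);

    // A single 7-day IntervalDayTime, wrapped as a scalar Datum so it broadcasts.
    let mut builder = IntervalDayTimeBuilder::new();
    builder.append_value(IntervalDayTime::new(7, 0));
    let seven_days = Scalar::new(builder.finish());

    // Arrow's numeric kernel understands Date32 + IntervalDayTime.
    let shifted = add(&dates, &seven_days)?;
    assert_eq!(shifted.len(), 2);
    Ok(())
}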

/// Spark-compatible `date_add` and `date_sub` expressions, which expect a number of days as the
/// second argument. An integer cannot be added directly to a Date32, so we build an
/// IntervalDayTime from the second argument and use DataFusion's `datum::apply` to invoke
/// Arrow's arithmetic kernels.
fn spark_date_arithmetic(
    args: &[ColumnarValue],
    op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
) -> Result<ColumnarValue, DataFusionError> {
    let start = &args[0];
    match &args[1] {
        ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
            scalar_date_arithmetic!(start, days, op)
        }
        ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
            scalar_date_arithmetic!(start, days, op)
        }
        ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
            scalar_date_arithmetic!(start, days, op)
        }
        ColumnarValue::Array(days) => {
            let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
            match days.data_type() {
                DataType::Int8 => {
                    array_date_arithmetic!(days, interval_builder, Int8Type)
                }
                DataType::Int16 => {
                    array_date_arithmetic!(days, interval_builder, Int16Type)
                }
                DataType::Int32 => {
                    array_date_arithmetic!(days, interval_builder, Int32Type)
                }
                _ => {
                    return Err(DataFusionError::Internal(format!(
                        "Unsupported data types {:?} for date arithmetic.",
                        args,
                    )))
                }
            }
            let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
            datum::apply(start, &interval_cv, op)
        }
        _ => Err(DataFusionError::Internal(format!(
            "Unsupported data types {:?} for date arithmetic.",
            args,
        ))),
    }
}

pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    spark_date_arithmetic(args, add)
}

pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    spark_date_arithmetic(args, sub)
}
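
A hypothetical usage sketch (not part of this commit; the function name and sample values are invented for illustration) exercising both the scalar and the array code paths, assuming `ScalarValue`, `Arc`, and `Int16Array` are in scope as they are elsewhere in scalar_funcs.rs:

fn date_arithmetic_example() -> Result<(), DataFusionError> {
    // Scalar path: Date32 stores days since the Unix epoch, so 19000 is 2022-01-08.
    let start = ColumnarValue::Scalar(ScalarValue::Date32(Some(19000)));
    let seven = ColumnarValue::Scalar(ScalarValue::Int32(Some(7)));
    let _plus_week = spark_date_add(&[start.clone(), seven])?;

    // Array path: per-row day counts, here Int16 to exercise the widened integer support.
    let days = ColumnarValue::Array(Arc::new(Int16Array::from(vec![1i16, 30, 365])));
    let _minus_days = spark_date_sub(&[start, days])?;
    Ok(())
}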
