Skip to content

Commit

Permalink
Add modulus_dyn and modulus_scalar_dyn
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Feb 1, 2023
1 parent f9a78e0 commit aab7b7a
Showing 1 changed file with 105 additions and 1 deletion.
106 changes: 105 additions & 1 deletion arrow-arith/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,45 @@ where
});
}

/// Perform `left % right` operation on two arrays. If either left or right value is null
/// then the result is also null. If any right hand value is zero then the result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
pub fn modulus_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef, ArrowError> {
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_math_op!(
left,
right,
|a, b| {
if b.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(a.mod_wrapping(b))
}
},
math_divide_checked_op_dict
)
}
_ => {
downcast_primitive_array!(
(left, right) => {
math_checked_divide_op(left, right, |a, b| {
if b.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(a.mod_wrapping(b))
}
}).map(|a| Arc::new(a) as ArrayRef)
}
_ => Err(ArrowError::CastError(format!(
"Unsupported data type {}, {}",
left.data_type(), right.data_type()
)))
)
}
}
}

/// Perform `left / right` operation on two arrays. If either left or right value is null
/// then the result is also null. If any right hand value is zero then the result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
Expand Down Expand Up @@ -1551,6 +1590,23 @@ where
Ok(unary(array, |a| a.mod_wrapping(modulo)))
}

/// Modulus every value in an array by a scalar. If any value in the array is null then the
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`.
pub fn modulus_scalar_dyn<T>(
array: &dyn Array,
modulo: T::Native,
) -> Result<ArrayRef, ArrowError>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
if modulo.is_zero() {
return Err(ArrowError::DivideByZero);
}
unary_dyn::<_, T>(array, |value| value.mod_wrapping(modulo))
}

/// Divide every value in an array by a scalar. If any value in the array is null then the
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`.
Expand Down Expand Up @@ -2170,6 +2226,14 @@ mod tests {
assert_eq!(0, c.value(2));
assert_eq!(1, c.value(3));
assert_eq!(0, c.value(4));

let c = modulus_dyn(&a, &b).unwrap();
let c = as_primitive_array::<Int32Type>(&c);
assert_eq!(0, c.value(0));
assert_eq!(3, c.value(1));
assert_eq!(0, c.value(2));
assert_eq!(1, c.value(3));
assert_eq!(0, c.value(4));
}

#[test]
Expand All @@ -2182,6 +2246,16 @@ mod tests {
modulus(&a, &b).unwrap();
}

#[test]
#[should_panic(
expected = "called `Result::unwrap()` on an `Err` value: DivideByZero"
)]
fn test_int_array_modulus_dyn_divide_by_zero() {
let a = Int32Array::from(vec![1]);
let b = Int32Array::from(vec![0]);
modulus_dyn(&a, &b).unwrap();
}

#[test]
fn test_int_array_modulus_overflow_wrapping() {
let a = Int32Array::from(vec![i32::MIN]);
Expand Down Expand Up @@ -2258,6 +2332,11 @@ mod tests {
let c = modulus_scalar(&a, b).unwrap();
let expected = Int32Array::from(vec![0, 2, 0, 2, 1]);
assert_eq!(c, expected);

let c = modulus_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = as_primitive_array::<Int32Type>(&c);
let expected = Int32Array::from(vec![0, 2, 0, 2, 1]);
assert_eq!(c, &expected);
}

#[test]
Expand All @@ -2268,6 +2347,11 @@ mod tests {
let actual = modulus_scalar(a, 3).unwrap();
let expected = Int32Array::from(vec![None, Some(0), Some(2), None]);
assert_eq!(actual, expected);

let actual = modulus_scalar_dyn::<Int32Type>(a, 3).unwrap();
let actual = as_primitive_array::<Int32Type>(&actual);
let expected = Int32Array::from(vec![None, Some(0), Some(2), None]);
assert_eq!(actual, &expected);
}

#[test]
Expand All @@ -2283,7 +2367,11 @@ mod tests {
fn test_int_array_modulus_scalar_overflow_wrapping() {
let a = Int32Array::from(vec![i32::MIN]);
let result = modulus_scalar(&a, -1).unwrap();
assert_eq!(0, result.value(0))
assert_eq!(0, result.value(0));

let result = modulus_scalar_dyn::<Int32Type>(&a, -1).unwrap();
let result = as_primitive_array::<Int32Type>(&result);
assert_eq!(0, result.value(0));
}

#[test]
Expand Down Expand Up @@ -2566,6 +2654,14 @@ mod tests {
modulus(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_i32_array_modulus_dyn_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
modulus_dyn(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_f32_array_modulus_by_zero() {
Expand All @@ -2574,6 +2670,14 @@ mod tests {
modulus(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_f32_array_modulus_dyn_by_zero() {
let a = Float32Array::from(vec![1.5]);
let b = Float32Array::from(vec![0.0]);
modulus_dyn(&a, &b).unwrap();
}

#[test]
fn test_f64_array_divide() {
let a = Float64Array::from(vec![15.0, 15.0, 8.0]);
Expand Down

0 comments on commit aab7b7a

Please sign in to comment.