Skip to content

Commit

Permalink
Check overflow when casting floating point value to decimal256 (#3033)
Browse files Browse the repository at this point in the history
* Check overflow when casting floating point value to decimal256

* Add from_f64
  • Loading branch information
viirya authored Nov 7, 2022
1 parent 12f0ef4 commit 6dd9dae
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 5 deletions.
15 changes: 14 additions & 1 deletion arrow-buffer/src/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use num::cast::AsPrimitive;
use num::BigInt;
use num::{BigInt, FromPrimitive};
use std::cmp::Ordering;

/// A signed 256-bit integer
Expand Down Expand Up @@ -102,6 +102,19 @@ impl i256 {
Self::from_parts(v as u128, v >> 127)
}

/// Create an optional i256 from the provided `f64`. Returning `None`
/// if overflow occurred
pub fn from_f64(v: f64) -> Option<Self> {
BigInt::from_f64(v).and_then(|i| {
let (integer, overflow) = i256::from_bigint_with_overflow(i);
if overflow {
None
} else {
Some(integer)
}
})
}

/// Create an i256 from the provided low u128 and high i128
#[inline]
pub const fn from_parts(low: u128, high: i128) -> Self {
Expand Down
59 changes: 55 additions & 4 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,38 @@ fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
let mul = 10_f64.powi(scale as i32);

array
.unary::<_, Decimal256Type>(|v| i256::from_i128((v.as_() * mul).round() as i128))
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
if cast_options.safe {
let iter = array
.iter()
.map(|v| v.and_then(|v| i256::from_f64((v.as_() * mul).round())));
let casted_array =
unsafe { PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter) };
casted_array
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, Decimal256Type, _>(|v| {
i256::from_f64((v.as_() * mul).round()).ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
precision,
scale,
v
))
})
})
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
}

/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
Expand Down Expand Up @@ -666,11 +688,13 @@ pub fn cast_with_options(
as_primitive_array::<Float32Type>(array),
*precision,
*scale,
cast_options,
),
Float64 => cast_floating_point_to_decimal256(
as_primitive_array::<Float64Type>(array),
*precision,
*scale,
cast_options,
),
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
Expand Down Expand Up @@ -6166,4 +6190,31 @@ mod tests {
err
);
}

#[test]
fn test_cast_floating_point_to_decimal256_overflow() {
let array = Float64Array::from(vec![f64::MAX]);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast_with_options(
&array,
&DataType::Decimal256(76, 50),
&CastOptions { safe: true },
);
assert!(casted_array.is_ok());
assert!(casted_array.unwrap().is_null(0));

let casted_array = cast_with_options(
&array,
&DataType::Decimal256(76, 50),
&CastOptions { safe: false },
);
let err = casted_array.unwrap_err().to_string();
let expected_error = "Cast error: Cannot cast to Decimal256(76, 50)";
assert!(
err.contains(expected_error),
"did not find expected error '{}' in actual error '{}'",
expected_error,
err
);
}
}

0 comments on commit 6dd9dae

Please sign in to comment.