feat: Implement Spark-compatible CAST from floating-point/double to decimal #384

Merged: 11 commits, May 9, 2024
11 changes: 11 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,13 @@ pub enum CometError {
        to_type: String,
    },

    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
    NumericValueOutOfRange {
        value: String,
        precision: u8,
        scale: i8,
    },

    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
        due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
        set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
@@ -208,6 +215,10 @@ impl jni::errors::ToException for CometError {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
CometError::NumericValueOutOfRange { .. } => Exception {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
CometError::NumberIntFormat { source: s } => Exception {
class: "java/lang/NumberFormatException".to_string(),
msg: s.to_string(),
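As an aside on the mechanics: the `#[error(...)]` attribute is a `thiserror` derive, so the bracketed error-class prefix and the field interpolation above are exactly what `self.to_string()` yields when `ToException` wraps the variant in an `org.apache.spark.SparkException`. A minimal standalone sketch (a stand-in enum with an abbreviated version of the variant's message, not the PR's code):

```rust
use thiserror::Error;

// Stand-in enum mirroring the new variant, to show what Display
// (and hence the JNI exception message) produces.
#[derive(Error, Debug)]
enum SketchError {
    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}).")]
    NumericValueOutOfRange {
        value: String,
        precision: u8,
        scale: i8,
    },
}

fn main() {
    let err = SketchError::NumericValueOutOfRange {
        value: "12345678.91".to_string(),
        precision: 10,
        scale: 2,
    };
    // Prints:
    // [NUMERIC_VALUE_OUT_OF_RANGE] 12345678.91 cannot be represented as Decimal(10, 2).
    println!("{err}");
}
```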
86 changes: 84 additions & 2 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -25,7 +25,10 @@ use std::{
use crate::errors::{CometError, CometResult};
use arrow::{
    compute::{cast_with_options, CastOptions},
    datatypes::TimestampMicrosecondType,
    datatypes::{
        ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type,
        TimestampMicrosecondType,
    },
    record_batch::RecordBatch,
    util::display::FormatOptions,
};
@@ -39,7 +42,7 @@ use chrono::{TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive};
use regex::Regex;

use crate::execution::datafusion::expressions::utils::{
@@ -332,6 +335,12 @@ impl Cast {
            (DataType::Float32, DataType::LargeUtf8) => {
                Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
            }
            (DataType::Float32, DataType::Decimal128(precision, scale)) => {
                Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)?
            }
            (DataType::Float64, DataType::Decimal128(precision, scale)) => {
                Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)?
            }
            _ => {
                // when we have no Spark-specific casting we delegate to DataFusion
                cast_with_options(&array, to_type, &CAST_OPTIONS)?
@@ -395,6 +404,79 @@ impl Cast {
        Ok(cast_array)
    }

    fn cast_float64_to_decimal128(
        array: &dyn Array,
        precision: u8,
        scale: i8,
        eval_mode: EvalMode,
    ) -> CometResult<ArrayRef> {
        Self::cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
    }

    fn cast_float32_to_decimal128(
        array: &dyn Array,
        precision: u8,
        scale: i8,
        eval_mode: EvalMode,
    ) -> CometResult<ArrayRef> {
        Self::cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
    }

    fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
        array: &dyn Array,
        precision: u8,
        scale: i8,
        eval_mode: EvalMode,
    ) -> CometResult<ArrayRef>
    where
        <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
    {
        let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
        let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

        let mul = 10_f64.powi(scale as i32);

        for i in 0..input.len() {
            if input.is_null(i) {
                cast_array.append_null();
            } else {
                let input_value = input.value(i).as_();
                let value = (input_value * mul).round().to_i128();

                match value {
                    Some(v) => {
                        if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
                            return Err(CometError::NumericValueOutOfRange {
                                value: input_value.to_string(),
                                precision,
                                scale,
                            });
                        }
                        cast_array.append_value(v);
                    }
                    None => {
                        if eval_mode == EvalMode::Ansi {
                            return Err(CometError::NumericValueOutOfRange {
                                value: input_value.to_string(),
                                precision,
                                scale,
                            });
                        } else {
                            cast_array.append_null();
                        }
                    }
                }
            }
        }

        let res = Arc::new(
            cast_array
                .with_precision_and_scale(precision, scale)?
                .finish(),
        ) as ArrayRef;
        Ok(res)
    }

    fn spark_cast_float64_to_utf8<OffsetSize>(
        from: &dyn Array,
        _eval_mode: EvalMode,
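The core of the new path is plain scale-and-round arithmetic: multiply the float by 10^scale, round half away from zero (`f64::round`), and check that the resulting unscaled integer fits in the target precision. A self-contained sketch of just that arithmetic, without the Arrow builder plumbing (the helper name and the explicit i128 range check are illustrative, not the PR's API):

```rust
// Illustrative helper (not the PR's API): the same scale-and-round logic
// without Arrow builders. Assumes precision <= 38, Decimal128's maximum.
fn to_unscaled_i128(v: f64, precision: u8, scale: i8) -> Option<i128> {
    let scaled = (v * 10f64.powi(scale as i32)).round();
    // Reject NaN/infinity and anything outside i128 before casting; the PR
    // gets this for free from the `num` crate's checked `to_i128`.
    if !scaled.is_finite() || scaled.abs() >= 2f64.powi(127) {
        return None;
    }
    let unscaled = scaled as i128;
    // Mirror of Decimal128Type::validate_decimal_precision: the unscaled
    // value must fit in `precision` decimal digits.
    if unscaled.unsigned_abs() >= 10u128.pow(precision as u32) {
        return None;
    }
    Some(unscaled)
}

fn main() {
    // 12345.678 as Decimal(10, 2): 12345.678 * 100 = 1234567.8, which rounds
    // to the unscaled value 1234568, i.e. the decimal 12345.68.
    assert_eq!(to_unscaled_i128(12345.678, 10, 2), Some(1234568));
    // 1e9 needs 11 digits once scaled by 100 but Decimal(10, 2) allows only
    // 10, so the precision check rejects it: the [NUMERIC_VALUE_OUT_OF_RANGE]
    // case.
    assert_eq!(to_unscaled_i128(1e9, 10, 2), None);
}
```

Note that in the PR code above the two failure modes differ: a value that rounds to a representable i128 but exceeds the target precision raises `NumericValueOutOfRange` in both modes, while a failed `to_i128` (NaN, infinity, or magnitude beyond i128) errors only under ANSI mode and becomes null otherwise.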
4 changes: 2 additions & 2 deletions docs/source/user-guide/compatibility.md
@@ -89,9 +89,11 @@ The following cast operations are generally compatible with Spark except for the
| long | string | |
| float | boolean | |
| float | double | |
| float | decimal | |
| float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
| double | boolean | |
| double | float | |
| double | decimal | |
| double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
| decimal | float | |
| decimal | double | |
@@ -115,8 +117,6 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
| float | decimal | No overflow check |
| double | decimal | No overflow check |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |

@@ -227,13 +227,13 @@ object CometCast {

  private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
    case DataTypes.BooleanType | DataTypes.DoubleType => Compatible()
    case _: DecimalType => Incompatible(Some("No overflow check"))
    case _: DecimalType => Compatible()
    case _ => Unsupported
  }

  private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
    case DataTypes.BooleanType | DataTypes.FloatType => Compatible()
    case _: DecimalType => Incompatible(Some("No overflow check"))
    case _: DecimalType => Compatible()
    case _ => Unsupported
  }
16 changes: 11 additions & 5 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -344,8 +344,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    castTest(generateFloats(), DataTypes.DoubleType)
  }

  ignore("cast FloatType to DecimalType(10,2)") {
    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
  test("cast FloatType to DecimalType(10,2)") {
    castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
  }

@@ -402,8 +401,7 @@
    castTest(generateDoubles(), DataTypes.FloatType)
  }

  ignore("cast DoubleType to DecimalType(10,2)") {
    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
  test("cast DoubleType to DecimalType(10,2)") {
    castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
  }

@@ -960,11 +958,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
        val cometMessageModified = cometMessage
          .replace("[CAST_INVALID_INPUT] ", "")
          .replace("[CAST_OVERFLOW] ", "")
          .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "")

        if (sparkMessage.contains("cannot be represented as")) {
          assert(cometMessage.contains("cannot be represented as"))
        } else {
          assert(cometMessageModified == sparkMessage)
        }
      } else {
        // for Spark 3.2 we just make sure we are seeing a similar type of error
        if (sparkMessage.contains("causes overflow")) {
          assert(cometMessage.contains("due to an overflow"))
        } else if (sparkMessage.contains("cannot be represented as")) {
          assert(cometMessage.contains("cannot be represented as"))
Comment on lines +1006 to +1016

Member: I can see that the approach we have for handling error message comparison for Spark 3.2 and 3.3 needs some rethinking. I am going to make a proposal to improve this.

#402

Contributor Author: okay, that would be great.

        } else {
          // assume that this is an invalid input message in the form:
          // `invalid input syntax for type numeric: -9223372036854775809`