Skip to content

Commit

Permalink
fix: divide-by-zero error on decimal division
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 6, 2024
1 parent 9a58880 commit 0fe9dd1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
3 changes: 2 additions & 1 deletion core/src/execution/datafusion/expressions/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,11 @@ fn spark_decimal_div(
let l_mul = ten.pow(l_exp);
let r_mul = ten.pow(r_exp);
let five = BigInt::from(5);
let zero = BigInt::from(0);
let result: Decimal128Array = arrow::compute::kernels::arity::binary(left, right, |l, r| {
let l = BigInt::from(l) * &l_mul;
let r = BigInt::from(r) * &r_mul;
let div = &l / &r;
let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
let res = if div.is_negative() {
div - &five
} else {
Expand Down
31 changes: 30 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.util
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.{expr, lit}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.sql.types.{Decimal, DecimalType, StructType}
Expand All @@ -47,6 +47,35 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("decimals divide by zero") {
  // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal remainder operation
  assume(isSpark34Plus)

  // Exercise both Parquet encodings: dictionary-encoded and plain.
  for (useDictionary <- Seq(true, false)) {
    withSQLConf(
      SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false",
      "parquet.enable.dictionary" -> useDictionary.toString) {
      withTempPath { dir =>
        // Write a small decimal dataset to Parquet, then read it back through Comet.
        val data = makeDecimalRDD(10, DecimalType(18, 10), useDictionary)
        data.write.parquet(dir.getCanonicalPath)
        readParquetFile(dir.getCanonicalPath) { df =>
          // Dividing by a zero decimal literal must produce the same result
          // (null per ANSI-off Spark semantics) through Comet as through Spark.
          val zeroDecimal = Decimal(0.00)
          val cometDf = df.select($"dec" / zeroDecimal)
          val expected = data
            .select($"dec" / zeroDecimal)
            .collect()
            .toSeq

          checkAnswer(cometDf, expected)
        }
      }
    }
  }
}

test("bitwise shift with different left/right types") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
Expand Down

0 comments on commit 0fe9dd1

Please sign in to comment.