diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index 8ff13e125..78775cd5c 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -480,10 +480,11 @@ fn spark_decimal_div(
     let l_mul = ten.pow(l_exp);
     let r_mul = ten.pow(r_exp);
     let five = BigInt::from(5);
+    let zero = BigInt::from(0);
     let result: Decimal128Array = arrow::compute::kernels::arity::binary(left, right, |l, r| {
         let l = BigInt::from(l) * &l_mul;
         let r = BigInt::from(r) * &r_mul;
-        let div = &l / &r;
+        let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
         let res = if div.is_negative() {
             div - &five
         } else {
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index b3e60f58a..7424f1bfd 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -24,7 +24,7 @@ import java.util
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.functions.{expr, lit}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.sql.types.{Decimal, DecimalType, StructType}
@@ -47,6 +47,29 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("decimals divide by zero") {
+    // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal divide operation
+    assume(isSpark34Plus)
+
+    Seq(true, false).foreach { dictionary =>
+      withSQLConf(
+        SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false",
+        "parquet.enable.dictionary" -> dictionary.toString) {
+        withTempPath { dir =>
+          val data = makeDecimalRDD(10, DecimalType(18, 10), dictionary)
+          data.write.parquet(dir.getCanonicalPath)
+          readParquetFile(dir.getCanonicalPath) { df =>
+            {
+              val decimalLiteral = Decimal(0.00)
+              val cometDf = df.select($"dec" / decimalLiteral, $"dec" % decimalLiteral)
+              checkSparkAnswerAndOperator(cometDf)
+            }
+          }
+        }
+      }
+    }
+  }
+
   test("bitwise shift with different left/right types") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
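
The one-line guard in the Rust hunk is the whole fix: `BigInt` division panics on a zero divisor, so the closure substitutes zero instead of dividing. Below is a minimal standalone sketch of that logic, assuming the `num-bigint` and `num-traits` crates used by the patched file. `checked_decimal_div` and its exponent parameters are illustrative names, not Comet's API, and the final `/ &ten` step is an assumption: the hunk truncates right after the negative branch, but the visible `± 5` nudge implies a trailing divide-by-ten for half-up rounding.

```rust
// Hypothetical standalone sketch, not Comet's actual helper.
// Assumes num-bigint and num-traits, as in scalar_funcs.rs.
use num_bigint::BigInt;
use num_traits::Signed;

fn checked_decimal_div(l: i128, r: i128, l_exp: u32, r_exp: u32) -> i128 {
    let ten = BigInt::from(10);
    let five = BigInt::from(5);
    let zero = BigInt::from(0);

    // Rescale both operands so the quotient carries one extra digit for rounding.
    let l = BigInt::from(l) * ten.pow(l_exp);
    let r = BigInt::from(r) * ten.pow(r_exp);

    // The fix: BigInt division by zero panics, so short-circuit to zero.
    // Spark returns null for x / 0 outside ANSI mode, so this value is only
    // a placeholder that keeps the kernel from aborting the batch.
    let div = if r == zero { zero } else { &l / &r };

    // Round half-up away from zero: nudge by +/-5, then drop the guard digit.
    // (Assumed continuation of the truncated hunk.)
    let res = (if div.is_negative() { div - &five } else { div + &five }) / &ten;
    res.try_into().expect("quotient fits in i128")
}

fn main() {
    // A zero divisor no longer panics; it yields the placeholder 0.
    assert_eq!(checked_decimal_div(12345, 0, 10, 10), 0);
    println!("decimal divide-by-zero guarded");
}
```

Guarding inside the per-row closure keeps the vectorized `arity::binary` kernel intact rather than pre-scanning the divisor column, at the cost of one comparison per row. The `zero.clone()` in the patch exists because the closure only captures `zero` by reference, so the zero branch must produce an owned `BigInt` on each call; the Scala test then verifies end to end that Comet's result for `dec / 0` and `dec % 0` matches Spark's.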