Skip to content

Commit

Permalink
[SPARK-48016][SQL] Fix a bug in try_divide function when with decimals
Browse files Browse the repository at this point in the history
 Currently, the following query throws a DIVIDE_BY_ZERO error instead of returning null:
 ```
SELECT try_divide(1, decimal(0));
```

This is caused by the rule `DecimalPrecision`:
```
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
  (left, right) match {
 ...
    case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
        l.dataType.isInstanceOf[IntegralType] &&
        literalPickMinimumPrecision =>
      b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r))
```
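The expression returned by the `makeCopy` call above has `ANSI` as its `evalMode` instead of `TRY`: only the two children are passed to `makeCopy`, so the original expression's `evalMode` is not carried over to the copy.
This PR fixes the bug by replacing the `makeCopy` calls with `withNewChildren`, which swaps the children while keeping the original `evalMode`.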

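### Why are the changes needed?

Bug fix in try_* functions.

### Does this PR introduce any user-facing change?

Yes, it fixes a long-standing bug in the try_divide function.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No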

Closes #46286 from gengliangwang/avoidMakeCopy.

Authored-by: Gengliang Wang <gengliang@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
(cherry picked from commit 3fbcb26)
Signed-off-by: Gengliang Wang <gengliang@apache.org>
  • Loading branch information
gengliangwang committed Apr 29, 2024
1 parent 616c216 commit e78ee2c
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ object DecimalPrecision extends TypeCoercionRule {
val resultType = widerDecimalType(p1, s1, p2, s2)
val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
b.makeCopy(Array(newE1, newE2))
b.withNewChildren(Seq(newE1, newE2))
}

/**
Expand Down Expand Up @@ -202,21 +202,21 @@ object DecimalPrecision extends TypeCoercionRule {
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
l.dataType.isInstanceOf[IntegralType] &&
literalPickMinimumPrecision =>
b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r))
b.withNewChildren(Seq(Cast(l, DataTypeUtils.fromLiteral(l)), r))
case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
r.dataType.isInstanceOf[IntegralType] &&
literalPickMinimumPrecision =>
b.makeCopy(Array(l, Cast(r, DataTypeUtils.fromLiteral(r))))
b.withNewChildren(Seq(l, Cast(r, DataTypeUtils.fromLiteral(r))))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case (l @ IntegralTypeExpression(), r @ DecimalExpression(_, _)) =>
b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
b.withNewChildren(Seq(Cast(l, DecimalType.forType(l.dataType)), r))
case (l @ DecimalExpression(_, _), r @ IntegralTypeExpression()) =>
b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
b.withNewChildren(Seq(l, Cast(r, DecimalType.forType(r.dataType))))
case (l, r @ DecimalExpression(_, _)) if isFloat(l.dataType) =>
b.makeCopy(Array(l, Cast(r, DoubleType)))
b.withNewChildren(Seq(l, Cast(r, DoubleType)))
case (l @ DecimalExpression(_, _), r) if isFloat(r.dataType) =>
b.makeCopy(Array(Cast(l, DoubleType), r))
b.withNewChildren(Seq(Cast(l, DoubleType), r))
case _ => b
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,22 +1102,22 @@ object TypeCoercion extends TypeCoercionBase {

case a @ BinaryArithmetic(left @ StringTypeExpression(), right)
if right.dataType != CalendarIntervalType =>
a.makeCopy(Array(Cast(left, DoubleType), right))
a.withNewChildren(Seq(Cast(left, DoubleType), right))
case a @ BinaryArithmetic(left, right @ StringTypeExpression())
if left.dataType != CalendarIntervalType =>
a.makeCopy(Array(left, Cast(right, DoubleType)))
a.withNewChildren(Seq(left, Cast(right, DoubleType)))

// For equality between string and timestamp we cast the string to a timestamp
// so that things like rounding of subsecond precision does not affect the comparison.
case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) =>
p.makeCopy(Array(Cast(left, TimestampType), right))
p.withNewChildren(Seq(Cast(left, TimestampType), right))
case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))
p.withNewChildren(Seq(left, Cast(right, TimestampType)))

case p @ BinaryComparison(left, right)
if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
p.withNewChildren(Seq(castExpr(left, commonType), castExpr(right, commonType)))
}
}

Expand Down
2 changes: 1 addition & 1 deletion sql/core/src/test/resources/log4j2.properties
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ logger.parquet_recordwriter.name = org.apache.parquet.hadoop.InternalParquetReco
logger.parquet_recordwriter.additivity = false
logger.parquet_recordwriter.level = off

logger.parquet_outputcommitter.name = org.apache.parquet.hadoop.ParquetOutputCommitter
logger.parquet_outputcommitter.name = org.apache.parquet.hadoop.ParquetOutputCommitter
logger.parquet_outputcommitter.additivity = false
logger.parquet_outputcommitter.level = off

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, decimal(1))
-- !query analysis
Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, "1")
-- !query analysis
Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#xL]
+- OneRowRelation


-- !query
SELECT try_add(-2147483648, -1)
-- !query analysis
Expand Down Expand Up @@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, decimal(0))
-- !query analysis
Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, "0")
-- !query analysis
Project [try_divide(1, 0) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(interval 2 year, 2)
-- !query analysis
Expand Down Expand Up @@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, decimal(-1))
-- !query analysis
Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, "-1")
-- !query analysis
Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#xL]
+- OneRowRelation


-- !query
SELECT try_subtract(-2147483648, 1)
-- !query analysis
Expand Down Expand Up @@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, decimal(-2))
-- !query analysis
Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, "-2")
-- !query analysis
Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#xL]
+- OneRowRelation


-- !query
SELECT try_multiply(-2147483648, 2)
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, decimal(1))
-- !query analysis
Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, "1")
-- !query analysis
Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(-2147483648, -1)
-- !query analysis
Expand Down Expand Up @@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, decimal(0))
-- !query analysis
Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, "0")
-- !query analysis
Project [try_divide(1, 0) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(interval 2 year, 2)
-- !query analysis
Expand Down Expand Up @@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, decimal(-1))
-- !query analysis
Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, "-1")
-- !query analysis
Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(-2147483648, 1)
-- !query analysis
Expand Down Expand Up @@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, decimal(-2))
-- !query analysis
Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, "-2")
-- !query analysis
Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(-2147483648, 2)
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
-- Numeric + Numeric
SELECT try_add(1, 1);
SELECT try_add(2147483647, 1);
SELECT try_add(2147483647, decimal(1));
SELECT try_add(2147483647, "1");
SELECT try_add(-2147483648, -1);
SELECT try_add(9223372036854775807L, 1);
SELECT try_add(-9223372036854775808L, -1);
Expand Down Expand Up @@ -38,6 +40,8 @@ SELECT try_divide(0, 0);
SELECT try_divide(1, (2147483647 + 1));
SELECT try_divide(1L, (9223372036854775807L + 1L));
SELECT try_divide(1, 1.0 / 0.0);
SELECT try_divide(1, decimal(0));
SELECT try_divide(1, "0");

-- Interval / Numeric
SELECT try_divide(interval 2 year, 2);
Expand All @@ -50,6 +54,8 @@ SELECT try_divide(interval 106751991 day, 0.5);
-- Numeric - Numeric
SELECT try_subtract(1, 1);
SELECT try_subtract(2147483647, -1);
SELECT try_subtract(2147483647, decimal(-1));
SELECT try_subtract(2147483647, "-1");
SELECT try_subtract(-2147483648, 1);
SELECT try_subtract(9223372036854775807L, -1);
SELECT try_subtract(-9223372036854775808L, 1);
Expand All @@ -66,6 +72,8 @@ SELECT try_subtract(interval 106751991 day, interval -3 day);
-- Numeric * Numeric
SELECT try_multiply(2, 3);
SELECT try_multiply(2147483647, -2);
SELECT try_multiply(2147483647, decimal(-2));
SELECT try_multiply(2147483647, "-2");
SELECT try_multiply(-2147483648, 2);
SELECT try_multiply(9223372036854775807L, 2);
SELECT try_multiply(-9223372036854775808L, -2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ struct<try_add(2147483647, 1):int>
NULL


-- !query
SELECT try_add(2147483647, decimal(1))
-- !query schema
struct<try_add(2147483647, 1):decimal(11,0)>
-- !query output
2147483648


-- !query
SELECT try_add(2147483647, "1")
-- !query schema
struct<try_add(2147483647, 1):bigint>
-- !query output
2147483648


-- !query
SELECT try_add(-2147483648, -1)
-- !query schema
Expand Down Expand Up @@ -341,6 +357,22 @@ org.apache.spark.SparkArithmeticException
}


-- !query
SELECT try_divide(1, decimal(0))
-- !query schema
struct<try_divide(1, 0):decimal(12,11)>
-- !query output
NULL


-- !query
SELECT try_divide(1, "0")
-- !query schema
struct<try_divide(1, 0):double>
-- !query output
NULL


-- !query
SELECT try_divide(interval 2 year, 2)
-- !query schema
Expand Down Expand Up @@ -405,6 +437,22 @@ struct<try_subtract(2147483647, -1):int>
NULL


-- !query
SELECT try_subtract(2147483647, decimal(-1))
-- !query schema
struct<try_subtract(2147483647, -1):decimal(11,0)>
-- !query output
2147483648


-- !query
SELECT try_subtract(2147483647, "-1")
-- !query schema
struct<try_subtract(2147483647, -1):bigint>
-- !query output
2147483648


-- !query
SELECT try_subtract(-2147483648, 1)
-- !query schema
Expand Down Expand Up @@ -547,6 +595,22 @@ struct<try_multiply(2147483647, -2):int>
NULL


-- !query
SELECT try_multiply(2147483647, decimal(-2))
-- !query schema
struct<try_multiply(2147483647, -2):decimal(21,0)>
-- !query output
-4294967294


-- !query
SELECT try_multiply(2147483647, "-2")
-- !query schema
struct<try_multiply(2147483647, -2):bigint>
-- !query output
-4294967294


-- !query
SELECT try_multiply(-2147483648, 2)
-- !query schema
Expand Down
Loading

0 comments on commit e78ee2c

Please sign in to comment.