Skip to content

Commit

Permalink
[SPARK-49993][SQL] Improve error messages for Sum and Average
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR improves the error messages for ANSI-related issues in Sum and Average.

### Why are the changes needed?
The [PR](#48206) for removing the ANSI suggestion in ARITHMETIC_OVERFLOW was getting too big, so that work is being split into multiple smaller tasks; this PR is one of them.

### Does this PR introduce _any_ user-facing change?
Yes, the overflow error messages for Sum and Average now suggest using `try_sum` and `try_avg`, respectively.

### How was this patch tested?
Existing tests in `DataFrameAggregateSuite` were updated to check for the new suggestions.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48499 from mihailom-db/fixSuggestions.

Authored-by: Mihailo Milosevic <mihailo.milosevic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
mihailom-db authored and MaxGekk committed Oct 28, 2024
1 parent cde8e4a commit 1cd72c6
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {

def arithmeticOverflowError(
message: String,
hint: String = "",
suggestedFunc: String = "",
context: QueryContext = null): ArithmeticException = {
val alternative = if (hint.nonEmpty) {
s" Use '$hint' to tolerate overflow and return NULL instead."
val alternative = if (suggestedFunc.nonEmpty) {
s" Use '$suggestedFunc' to tolerate overflow and return NULL instead."
} else ""
new SparkArithmeticException(
errorClass = "ARITHMETIC_OVERFLOW",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ case class CheckOverflowInSum(
val value = child.eval(input)
if (value == null) {
if (nullOnOverflow) null
else throw QueryExecutionErrors.overflowInSumOfDecimalError(context)
else {
throw QueryExecutionErrors.overflowInSumOfDecimalError(context, suggestedFunc = "try_sum")
}
} else {
value.asInstanceOf[Decimal].toPrecision(
dataType.precision,
Expand All @@ -183,7 +185,7 @@ case class CheckOverflowInSum(
val nullHandling = if (nullOnOverflow) {
""
} else {
s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
s"""throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, "try_sum");"""
}
// scalastyle:off line.size.limit
val code = code"""
Expand Down Expand Up @@ -270,7 +272,8 @@ case class DecimalDivideWithOverflowCheck(
if (nullOnOverflow) {
null
} else {
throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull())
throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull(),
suggestedFunc = "try_avg")
}
} else {
val value2 = right.eval(input)
Expand All @@ -286,7 +289,7 @@ case class DecimalDivideWithOverflowCheck(
val nullHandling = if (nullOnOverflow) {
""
} else {
s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
s"""throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, "try_avg");"""
}

val eval1 = left.genCode(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)))
}

def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = {
arithmeticOverflowError("Overflow in sum of decimals", context = context)
def overflowInSumOfDecimalError(
context: QueryContext,
suggestedFunc: String): ArithmeticException = {
arithmeticOverflowError("Overflow in sum of decimals", suggestedFunc = suggestedFunc,
context = context)
}

def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2270,19 +2270,20 @@ class DataFrameAggregateSuite extends QueryTest
}

private def assertDecimalSumOverflow(
df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
df: DataFrame, ansiEnabled: Boolean, fnName: String, expectedAnswer: Row): Unit = {
if (!ansiEnabled) {
checkAnswer(df, expectedAnswer)
} else {
val e = intercept[ArithmeticException] {
df.collect()
}
assert(e.getMessage.contains("cannot be represented as Decimal") ||
e.getMessage.contains("Overflow in sum of decimals"))
e.getMessage.contains(s"Overflow in sum of decimals. Use 'try_$fnName' to tolerate " +
s"overflow and return NULL instead."))
}
}

def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
def checkAggResultsForDecimalOverflow(aggFn: Column => Column, fnName: String): Unit = {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq(true, false).foreach { ansiEnabled =>
Expand All @@ -2306,27 +2307,27 @@ class DataFrameAggregateSuite extends QueryTest
join(df, "intNum").agg(aggFn($"decNum"))

val expectedAnswer = Row(null)
assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
assertDecimalSumOverflow(df2, ansiEnabled, fnName, expectedAnswer)

val decStr = "1" + "0" * 19
val d1 = spark.range(0, 12, 1, 1)
val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
assertDecimalSumOverflow(d2, ansiEnabled, fnName, expectedAnswer)

val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
assertDecimalSumOverflow(d4, ansiEnabled, fnName, expectedAnswer)

val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
assertDecimalSumOverflow(d5, ansiEnabled, fnName, expectedAnswer)

val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))

val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
toDF("d")
assertDecimalSumOverflow(
nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)
nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, fnName, expectedAnswer)

val df3 = Seq(
(BigDecimal("10000000000000000000"), 1),
Expand All @@ -2344,21 +2345,21 @@ class DataFrameAggregateSuite extends QueryTest
(BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")

val df6 = df3.union(df4).union(df5)
val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
val df7 = df6.groupBy("intNum").agg(aggFn($"decNum"), countDistinct("decNum")).
filter("intNum == 1")
assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
assertDecimalSumOverflow(df7, ansiEnabled, fnName, Row(1, null, 2))
}
}
}
}
}

test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
checkAggResultsForDecimalOverflow(c => sum(c))
checkAggResultsForDecimalOverflow(c => sum(c), "sum")
}

test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") {
checkAggResultsForDecimalOverflow(c => avg(c))
checkAggResultsForDecimalOverflow(c => avg(c), "avg")
}

test("SPARK-28224: Aggregate sum big decimal overflow") {
Expand All @@ -2369,7 +2370,7 @@ class DataFrameAggregateSuite extends QueryTest
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
assertDecimalSumOverflow(structDf, ansiEnabled, "sum", Row(null))
}
}
}
Expand Down

0 comments on commit 1cd72c6

Please sign in to comment.