diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 907c46f583cf1..0ee1d7037d438 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -116,10 +116,10 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
 
   def arithmeticOverflowError(
       message: String,
-      hint: String = "",
+      suggestedFunc: String = "",
       context: QueryContext = null): ArithmeticException = {
-    val alternative = if (hint.nonEmpty) {
-      s" Use '$hint' to tolerate overflow and return NULL instead."
+    val alternative = if (suggestedFunc.nonEmpty) {
+      s" Use '$suggestedFunc' to tolerate overflow and return NULL instead."
     } else ""
     new SparkArithmeticException(
       errorClass = "ARITHMETIC_OVERFLOW",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 5f13d397d1bf9..f7509f124ab50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -166,7 +166,9 @@ case class CheckOverflowInSum(
     val value = child.eval(input)
     if (value == null) {
       if (nullOnOverflow) null
-      else throw QueryExecutionErrors.overflowInSumOfDecimalError(context)
+      else {
+        throw QueryExecutionErrors.overflowInSumOfDecimalError(context, suggestedFunc = "try_sum")
+      }
     } else {
       value.asInstanceOf[Decimal].toPrecision(
         dataType.precision,
@@ -183,7 +185,7 @@ case class CheckOverflowInSum(
     val nullHandling = if (nullOnOverflow) {
       ""
     } else {
-      s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+      s"""throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, "try_sum");"""
     }
     // scalastyle:off line.size.limit
     val code = code"""
@@ -270,7 +272,8 @@ case class DecimalDivideWithOverflowCheck(
       if (nullOnOverflow) {
         null
       } else {
-        throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull())
+        throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull(),
+          suggestedFunc = "try_avg")
       }
     } else {
       val value2 = right.eval(input)
@@ -286,7 +289,7 @@ case class DecimalDivideWithOverflowCheck(
     val nullHandling = if (nullOnOverflow) {
       ""
     } else {
-      s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+      s"""throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode, "try_avg");"""
     }
 
     val eval1 = left.genCode(ctx)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index efdc06d4cbd8a..0aed8e604bd9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -295,8 +295,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
         "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)))
   }
 
-  def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = {
-    arithmeticOverflowError("Overflow in sum of decimals", context = context)
+  def overflowInSumOfDecimalError(
+      context: QueryContext,
+      suggestedFunc: String): ArithmeticException = {
+    arithmeticOverflowError("Overflow in sum of decimals", suggestedFunc = suggestedFunc,
+      context = context)
   }
 
   def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 25f4d9f62354a..7ebcb280def6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2270,7 +2270,7 @@ class DataFrameAggregateSuite extends QueryTest
   }
 
   private def assertDecimalSumOverflow(
-      df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
+      df: DataFrame, ansiEnabled: Boolean, fnName: String, expectedAnswer: Row): Unit = {
     if (!ansiEnabled) {
       checkAnswer(df, expectedAnswer)
     } else {
@@ -2278,11 +2278,12 @@ class DataFrameAggregateSuite extends QueryTest
         df.collect()
       }
       assert(e.getMessage.contains("cannot be represented as Decimal") ||
-        e.getMessage.contains("Overflow in sum of decimals"))
+        e.getMessage.contains(s"Overflow in sum of decimals. Use 'try_$fnName' to tolerate " +
+          s"overflow and return NULL instead."))
     }
   }
 
-  def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
+  def checkAggResultsForDecimalOverflow(aggFn: Column => Column, fnName: String): Unit = {
     Seq("true", "false").foreach { wholeStageEnabled =>
       withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
         Seq(true, false).foreach { ansiEnabled =>
@@ -2306,27 +2307,27 @@ class DataFrameAggregateSuite extends QueryTest
             join(df, "intNum").agg(aggFn($"decNum"))
 
           val expectedAnswer = Row(null)
-          assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
+          assertDecimalSumOverflow(df2, ansiEnabled, fnName, expectedAnswer)
 
           val decStr = "1" + "0" * 19
           val d1 = spark.range(0, 12, 1, 1)
           val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
-          assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
+          assertDecimalSumOverflow(d2, ansiEnabled, fnName, expectedAnswer)
 
           val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
           val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
-          assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
+          assertDecimalSumOverflow(d4, ansiEnabled, fnName, expectedAnswer)
 
           val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
            lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
-          assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
+          assertDecimalSumOverflow(d5, ansiEnabled, fnName, expectedAnswer)
 
           val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
 
           val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
            toDF("d")
           assertDecimalSumOverflow(
-            nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)
+            nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, fnName, expectedAnswer)
 
           val df3 = Seq(
             (BigDecimal("10000000000000000000"), 1),
@@ -2344,9 +2345,9 @@ class DataFrameAggregateSuite extends QueryTest
             (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
 
           val df6 = df3.union(df4).union(df5)
-          val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
+          val df7 = df6.groupBy("intNum").agg(aggFn($"decNum"), countDistinct("decNum")).
            filter("intNum == 1")
-          assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
+          assertDecimalSumOverflow(df7, ansiEnabled, fnName, Row(1, null, 2))
         }
       }
     }
@@ -2354,11 +2355,11 @@ class DataFrameAggregateSuite extends QueryTest
   }
 
   test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
-    checkAggResultsForDecimalOverflow(c => sum(c))
+    checkAggResultsForDecimalOverflow(c => sum(c), "sum")
   }
 
   test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") {
-    checkAggResultsForDecimalOverflow(c => avg(c))
+    checkAggResultsForDecimalOverflow(c => avg(c), "avg")
  }
 
   test("SPARK-28224: Aggregate sum big decimal overflow") {
@@ -2369,7 +2370,7 @@ class DataFrameAggregateSuite extends QueryTest
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
         val structDf = largeDecimals.select("a").agg(sum("a"))
-        assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
+        assertDecimalSumOverflow(structDf, ansiEnabled, "sum", Row(null))
       }
     }
   }