[SPARK-40769][CORE][SQL] Migrate type check failures of aggregate expressions onto error classes

### What changes were proposed in this pull request?
This PR aims to replace `TypeCheckFailure` with `DataTypeMismatch` in the type checks of the following aggregate expressions (a condensed before/after sketch follows the list):

- Count
- CollectSet
- CountMinSketchAgg
- HistogramNumeric
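
As a minimal illustration, here is the shape of the migration, condensed from the `Count` change in this commit (surrounding class and imports omitted; the message text itself moves out of the Scala source and into `error-classes.json`):

```scala
// Before: a free-form, hard-coded error string.
TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
  s"If you have to call the function $prettyName without arguments, set the legacy " +
  s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")

// After: a structured error class. Only the error sub-class and the message
// parameters are supplied at the call site; the template is looked up in
// error-classes.json under WRONG_NUM_ARGS_WITH_SUGGESTION.
DataTypeMismatch(
  errorSubClass = "WRONG_NUM_ARGS_WITH_SUGGESTION",
  messageParameters = Map(
    "functionName" -> toSQLId(prettyName),
    "expectedNum" -> " >= 1",
    "actualNum" -> "0",
    "legacyNum" -> "0",
    "legacyConfKey" -> toSQLConf(SQLConf.ALLOW_PARAMETERLESS_COUNT.key),
    "legacyConfValue" -> toSQLConfVal(true.toString)))
```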

### Why are the changes needed?
Migration onto error classes unifies Spark SQL error messages.

### Does this PR introduce _any_ user-facing change?
Yes. The PR changes user-facing error messages.
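
For example, `CountMinSketchAgg` with a non-positive relative error previously reported a free-form string and now reports a message rendered from the `VALUE_OUT_OF_RANGE` template in `error-classes.json`. Roughly (the exact final wording, including the `DATATYPE_MISMATCH` prefix and the value formatting produced by `toSQLValue`, is determined by the error framework):

```
-- Before
Relative error must be positive (current value = 0.0)

-- After (DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE)
The eps must be between (0.0, 1.7976931348623157E308] (current value = 0.0)
```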

### How was this patch tested?
Pass GitHub Actions

Closes #38498 from LuciferYang/SPARK-40769.

Authored-by: yangjie01 <yangjie01@baidu.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
LuciferYang authored and MaxGekk committed Nov 4, 2022
1 parent f31f64d commit a2a8de9
Showing 10 changed files with 218 additions and 54 deletions.
11 changes: 11 additions & 0 deletions core/src/main/resources/error/error-classes.json
@@ -370,6 +370,11 @@
"Cannot use an UnspecifiedFrame. This should have been converted during analysis."
]
},
"UNSUPPORTED_INPUT_TYPE" : {
"message" : [
"The input of <functionName> can't be <dataType> type data."
]
},
"VALUE_OUT_OF_RANGE" : {
"message" : [
"The <exprName> must be between <valueRange> (current value = <currentValue>)"
@@ -380,6 +385,12 @@
"The <functionName> requires <expectedNum> parameters but the actual number is <actualNum>."
]
},
"WRONG_NUM_ARGS_WITH_SUGGESTION" : {
"message" : [
"The <functionName> requires <expectedNum> parameters but the actual number is <actualNum>.",
"If you have to call this function with <legacyNum> parameters, set the legacy configuration <legacyConfKey> to <legacyConfValue>."
]
},
"WRONG_NUM_ENDPOINTS" : {
"message" : [
"The number of endpoints must be >= 2 to construct intervals but the actual number is <actualNumber>."
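Placeholders such as `<functionName>` and `<dataType>` are substituted from the `messageParameters` map at each call site. For instance, the new `UNSUPPORTED_INPUT_TYPE` entry, filled with the parameters supplied by the `CollectSet` change below, should render roughly as:

```
The input of `collect_set` can't be "MAP" type data.
```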
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -18,9 +18,11 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern}
+import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -45,7 +47,8 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "1.0.0")
// scalastyle:on line.size.limit
-case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
+case class Count(children: Seq[Expression]) extends DeclarativeAggregate
+  with QueryErrorsBase {

override def nullable: Boolean = false

@@ -56,9 +59,17 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) {
TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
s"If you have to call the function $prettyName without arguments, set the legacy " +
s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")
DataTypeMismatch(
errorSubClass = "WRONG_NUM_ARGS_WITH_SUGGESTION",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
"expectedNum" -> " >= 1",
"actualNum" -> "0",
"legacyNum" -> "0",
"legacyConfKey" -> toSQLConf(SQLConf.ALLOW_PARAMETERLESS_COUNT.key),
"legacyConfValue" -> toSQLConfVal(true.toString)
)
)
} else {
TypeCheckResult.TypeCheckSuccess
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.trees.QuaternaryLike
+import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
@@ -63,7 +64,8 @@ case class CountMinSketchAgg(
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[CountMinSketch]
with ExpectsInputTypes
-    with QuaternaryLike[Expression] {
+    with QuaternaryLike[Expression]
+    with QueryErrorsBase {

def this(
child: Expression,
@@ -82,17 +84,60 @@
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
-    } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
-        !seedExpression.foldable) {
-      TypeCheckFailure(
-        "The eps, confidence or seed provided must be a literal or foldable")
-    } else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
-        seedExpression.eval() == null) {
-      TypeCheckFailure("The eps, confidence or seed provided should not be null")
+    } else if (!epsExpression.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "eps",
+          "inputType" -> toSQLType(epsExpression.dataType),
+          "inputExpr" -> toSQLExpr(epsExpression))
+      )
+    } else if (!confidenceExpression.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "confidence",
+          "inputType" -> toSQLType(confidenceExpression.dataType),
+          "inputExpr" -> toSQLExpr(confidenceExpression))
+      )
+    } else if (!seedExpression.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "seed",
+          "inputType" -> toSQLType(seedExpression.dataType),
+          "inputExpr" -> toSQLExpr(seedExpression))
+      )
+    } else if (epsExpression.eval() == null) {
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_NULL",
+        messageParameters = Map("exprName" -> "eps"))
+    } else if (confidenceExpression.eval() == null) {
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_NULL",
+        messageParameters = Map("exprName" -> "confidence"))
+    } else if (seedExpression.eval() == null) {
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_NULL",
+        messageParameters = Map("exprName" -> "seed"))
     } else if (eps <= 0.0) {
-      TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
+      DataTypeMismatch(
+        errorSubClass = "VALUE_OUT_OF_RANGE",
+        messageParameters = Map(
+          "exprName" -> "eps",
+          "valueRange" -> s"(${0.toDouble}, ${Double.MaxValue}]",
+          "currentValue" -> toSQLValue(eps, DoubleType)
+        )
+      )
     } else if (confidence <= 0.0 || confidence >= 1.0) {
-      TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
+      DataTypeMismatch(
+        errorSubClass = "VALUE_OUT_OF_RANGE",
+        messageParameters = Map(
+          "exprName" -> "confidence",
+          "valueRange" -> s"(${0.toDouble}, ${1.toDouble}]",
+          "currentValue" -> toSQLValue(confidence, DoubleType)
+        )
+      )
} else {
TypeCheckSuccess
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
@@ -23,10 +23,11 @@ import com.google.common.primitives.{Doubles, Ints}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.NumericHistogram
@@ -63,7 +64,7 @@ case class HistogramNumeric(
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[NumericHistogram] with ImplicitCastInputTypes
-  with BinaryLike[Expression] {
+  with BinaryLike[Expression] with QueryErrorsBase {

def this(child: Expression, nBins: Expression) = {
this(child, nBins, 0, 0)
@@ -89,11 +90,26 @@
if (defaultCheck.isFailure) {
defaultCheck
} else if (!nBins.foldable) {
TypeCheckFailure(s"${this.prettyName} needs the nBins provided must be a constant literal.")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "nb",
"inputType" -> toSQLType(nBins.dataType),
"inputExpr" -> toSQLExpr(nBins))
)
} else if (nb == null) {
TypeCheckFailure(s"${this.prettyName} needs nBins value must not be null.")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> "nb"))
} else if (nb.asInstanceOf[Int] < 2) {
TypeCheckFailure(s"${this.prettyName} needs nBins to be at least 2, but you supplied $nb.")
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "nb",
"valueRange" -> s"[2, ${Int.MaxValue}]",
"currentValue" -> toSQLValue(nb, IntegerType)
)
)
} else {
TypeCheckSuccess
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -22,9 +22,11 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types._
import org.apache.spark.util.BoundedPriorityQueue

@@ -145,7 +147,8 @@ case class CollectList(
case class CollectSet(
child: Expression,
mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
+    inputAggBufferOffset: Int = 0)
+  extends Collect[mutable.HashSet[Any]] with QueryErrorsBase {

def this(child: Expression) = this(child, 0, 0)

@@ -177,7 +180,13 @@ case class CollectSet(
if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data")
DataTypeMismatch(
errorSubClass = "UNSUPPORTED_INPUT_TYPE",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
"dataType" -> toSQLType(MapType)
)
)
}
}

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala
@@ -92,6 +92,10 @@ private[sql] trait QueryErrorsBase {
quoteByDefault(conf)
}

+  def toSQLConfVal(conf: String): String = {
+    quoteByDefault(conf)
+  }

def toDSOption(option: String): String = {
quoteByDefault(option)
}
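The new `toSQLConfVal` helper quotes a configuration *value* the same way `toSQLConf` quotes a configuration *key*. As a rough, self-contained sketch (assuming `quoteByDefault` simply wraps its argument in double quotes, which is not shown in this hunk):

```scala
// Hypothetical stand-in for the real helper defined elsewhere in QueryErrorsBase.
def quoteByDefault(elem: String): String = "\"" + elem + "\""

def toSQLConfVal(conf: String): String = quoteByDefault(conf)

// Used in the Count change above:
// toSQLConfVal(true.toString) returns "\"true\""
```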
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -23,8 +23,9 @@ import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLValue
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
@@ -140,13 +141,24 @@ class CountMinSketchAggSuite extends SparkFunSuite {
epsExpression = Literal(epsOfTotalCount),
confidenceExpression = Literal(confidence),
seedExpression = AttributeReference("c", IntegerType)())

-    Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
-      assertResult(
-        TypeCheckFailure("The eps, confidence or seed provided must be a literal or foldable")) {
-        wrongAgg.checkInputDataTypes()
-      }
-    }
+    assertResult(
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        Map("inputName" -> "eps", "inputType" -> "\"DOUBLE\"", "inputExpr" -> "\"a\"")
+      )
+    )(wrongEps.checkInputDataTypes())
+    assertResult(
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        Map("inputName" -> "confidence", "inputType" -> "\"DOUBLE\"", "inputExpr" -> "\"b\"")
+      )
+    )(wrongConfidence.checkInputDataTypes())
+    assertResult(
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        Map("inputName" -> "seed", "inputType" -> "\"INT\"", "inputExpr" -> "\"c\"")
+      )
+    )(wrongSeed.checkInputDataTypes())
}

test("fails analysis if parameters are invalid") {
@@ -155,27 +167,52 @@
val wrongConfidence = cms(epsOfTotalCount, null, seed)
val wrongSeed = cms(epsOfTotalCount, confidence, null)

-    Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
-      assertResult(TypeCheckFailure("The eps, confidence or seed provided should not be null")) {
-        wrongAgg.checkInputDataTypes()
-      }
-    }
+    assertResult(
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_NULL",
+        Map("exprName" -> "eps")
+      )
+    )(wrongEps.checkInputDataTypes())
+    assertResult(
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_NULL",
+        Map("exprName" -> "confidence")
+      )
+    )(wrongConfidence.checkInputDataTypes())
+    assertResult(
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_NULL",
+        Map("exprName" -> "seed")
+      )
+    )(wrongSeed.checkInputDataTypes())

// parameters are out of the valid range
Seq(0.0, -1000.0).foreach { invalidEps =>
val invalidAgg = cms(invalidEps, confidence, seed)
assertResult(
TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")) {
invalidAgg.checkInputDataTypes()
}
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "eps",
"valueRange" -> s"(${0.toDouble}, ${Double.MaxValue}]",
"currentValue" -> toSQLValue(invalidEps, DoubleType)
)
)
)(invalidAgg.checkInputDataTypes())
}

Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence =>
val invalidAgg = cms(epsOfTotalCount, invalidConfidence, seed)
-      assertResult(TypeCheckFailure(
-        s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")) {
-        invalidAgg.checkInputDataTypes()
-      }
+      assertResult(
+        DataTypeMismatch(
+          errorSubClass = "VALUE_OUT_OF_RANGE",
+          messageParameters = Map(
+            "exprName" -> "confidence",
+            "valueRange" -> s"(${0.toDouble}, ${1.toDouble}]",
+            "currentValue" -> toSQLValue(invalidConfidence, DoubleType)
+          )
+        )
+      )(invalidAgg.checkInputDataTypes())
}
}
