Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-607] Fix error message enhancements for geometry functions #1555

Merged
merged 2 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,21 +160,20 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression])
override def nullable: Boolean = true

override def eval(inputRow: InternalRow): Any = {
val arg = inputExpressions.head.eval(inputRow)
try {
(inputExpressions.head.eval(inputRow)) match {
case (geomString: UTF8String) => {
arg match {
case geomString: UTF8String =>
// Parse UTF-8 encoded wkb string
Constructors.geomFromText(geomString.toString, FileDataSplitter.WKB).toGenericArrayData
}
case (wkb: Array[Byte]) => {
case wkb: Array[Byte] =>
// convert raw wkb byte array to geometry
Constructors.geomFromWKB(wkb).toGenericArrayData
}
case null => null
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(getClass.getSimpleName, Seq(arg), e)
}
}

Expand All @@ -201,21 +200,20 @@ case class ST_GeomFromEWKB(inputExpressions: Seq[Expression])
override def nullable: Boolean = true

override def eval(inputRow: InternalRow): Any = {
val arg = inputExpressions.head.eval(inputRow)
try {
(inputExpressions.head.eval(inputRow)) match {
case (geomString: UTF8String) => {
arg match {
case geomString: UTF8String =>
// Parse UTF-8 encoded wkb string
Constructors.geomFromText(geomString.toString, FileDataSplitter.WKB).toGenericArrayData
}
case (wkb: Array[Byte]) => {
case wkb: Array[Byte] =>
// convert raw wkb byte array to geometry
Constructors.geomFromWKB(wkb).toGenericArrayData
}
case null => null
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(getClass.getSimpleName, Seq(arg), e)
}
}

Expand Down Expand Up @@ -267,7 +265,10 @@ case class ST_LineFromWKB(inputExpressions: Seq[Expression])
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(wkb, srid),
e)
}
}

Expand Down Expand Up @@ -321,7 +322,10 @@ case class ST_LinestringFromWKB(inputExpressions: Seq[Expression])
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(wkb, srid),
e)
}
}

Expand Down Expand Up @@ -375,7 +379,10 @@ case class ST_PointFromWKB(inputExpressions: Seq[Expression])
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(wkb, srid),
e)
}
}

Expand Down Expand Up @@ -413,7 +420,6 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
override def eval(inputRow: InternalRow): Any = {
val geomString = inputExpressions.head.eval(inputRow).asInstanceOf[UTF8String].toString
try {

val geometry = Constructors.geomFromText(geomString, FileDataSplitter.GEOJSON)
// If the user specifies a bunch of attributes to go with each geometry, we need to store all of them in this geometry
if (inputExpressions.length > 1) {
Expand All @@ -422,7 +428,10 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
GeometrySerializer.serialize(geometry)
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geomString),
e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,10 @@ case class ST_IsValidDetail(children: Seq[Expression])
Seq(validDetail.valid, UTF8String.fromString(validDetail.reason), serLocation))
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, children, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry),
e)
}
}

Expand Down Expand Up @@ -627,20 +630,19 @@ case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])

override def eval(input: InternalRow): Any = {
val expr = inputExpressions(0)
val geometry = expr.toGeometry(input)

try {
val geometry = expr match {
case s: SerdeAware => s.evalWithoutSerialization(input)
case _ => expr.toGeometry(input)
}

geometry match {
case geometry: Geometry => getMinimumBoundingRadius(geometry)
case _ => null
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry),
e)
}
}

Expand Down Expand Up @@ -932,22 +934,24 @@ case class ST_SubDivideExplode(children: Seq[Expression]) extends Generator with
children.validateLength(2)

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val geometryRaw = children.head
val maxVerticesRaw = children(1)
val geometry = children.head.toGeometry(input)
val maxVertices = children(1).toInt(input)
try {
geometryRaw.toGeometry(input) match {
geometry match {
case geom: Geometry =>
ArrayData.toArrayData(
Functions.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData))
ArrayData.toArrayData(Functions.subDivide(geom, maxVertices).map(_.toGenericArrayData))
Functions
.subDivide(geom, maxVerticesRaw.toInt(input))
.subDivide(geom, maxVertices)
.map(_.toGenericArrayData)
.map(InternalRow(_))
case _ => new Array[InternalRow](0)
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, children, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry, maxVertices),
e)
}
}

Expand Down Expand Up @@ -1008,8 +1012,8 @@ case class ST_MaximumInscribedCircle(children: Seq[Expression])
with CodegenFallback {

override def eval(input: InternalRow): Any = {
val geometry = children.head.toGeometry(input)
try {
val geometry = children.head.toGeometry(input)
var inscribedCircle: InscribedCircle = null
inscribedCircle = Functions.maximumInscribedCircle(geometry)

Expand All @@ -1018,7 +1022,10 @@ case class ST_MaximumInscribedCircle(children: Seq[Expression])
InternalRow.fromSeq(Seq(serCenter, serNearest, inscribedCircle.radius))
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, children, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry),
e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
*/
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, Literal}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._

import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.typeOf
Expand Down Expand Up @@ -77,27 +79,40 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
override def dataType: DataType = f.sparkReturnType

private lazy val argExtractors: Array[InternalRow => Any] = f.buildExtractors(inputExpressions)
private lazy val argExtractors: Array[InternalRow => Any] = buildExtractors(inputExpressions)
private lazy val evaluator: InternalRow => Any = f.evaluatorBuilder(argExtractors)

private def findAllLiterals(expression: Expression): Seq[Literal] = {
expression match {
case lit: Literal => Seq(lit)
case _ => expression.children.flatMap(findAllLiterals)
}
}
// Remember input args to generate error messages when exceptions occur. The input arguments are
// helpful for troubleshooting the cause of errors.
private val inputArgs: ArrayBuffer[AnyRef] = ArrayBuffer.empty[AnyRef]

private def findAllLiteralsInExpressions(expressions: Seq[Expression]): Seq[String] = {
expressions.flatMap(findAllLiterals).map(_.value.toString)
private def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any] = {
f.argExtractorBuilders
.zipAll(expressions, null, null)
.flatMap {
case (null, _) => None
case (builder, expr) =>
val extractor = builder(expr)
Some((input: InternalRow) => {
val arg = extractor(input)
inputArgs += arg.asInstanceOf[AnyRef]
arg
})
}
.toArray
}

override def eval(input: InternalRow): Any = {

try {
f.serializer(evaluator(input))
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
inputArgs.toSeq,
e)
} finally {
inputArgs.clear()
}
}

Expand All @@ -106,32 +121,32 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
evaluator(input)
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
inputArgs.toSeq,
e)
} finally {
inputArgs.clear()
}
}
}

object InferredExpression {
def throwExpressionInferenceException(
input: InternalRow,
inputExpressions: Seq[Expression],
name: String,
inputArgs: Seq[Any],
e: Exception): Nothing = {
val literalsAsStrings = if (input == null) {
// In case no input row is provided, we can't extract literals from the input expressions.
inputExpressions.flatMap(findAllLiterals).map(_.value.toString)
if (e.isInstanceOf[InferredExpressionException]) {
throw e
} else {
Seq.empty[String]
}
val literalsOrInputString = literalsAsStrings.mkString(", ")
throw new InferredExpressionException(
s"Exception occurred while evaluating expression - source: [$literalsOrInputString]",
e)
}

def findAllLiterals(expression: Expression): Seq[Literal] = {
expression match {
case lit: Literal => Seq(lit)
case _ => expression.children.flatMap(findAllLiterals)
val inputsAsStrings = inputArgs.map { arg =>
val argStr = if (arg != null) arg.toString else "null"
StringUtils.abbreviate(argStr, 5000)
}
val inputsString = inputsAsStrings.mkString(", ")
throw new InferredExpressionException(
s"Exception occurred while evaluating expression $name - inputs: [$inputsString]",
e)
}
}
}
Expand Down Expand Up @@ -301,17 +316,7 @@ case class InferrableFunction(
sparkReturnType: DataType,
serializer: Any => Any,
argExtractorBuilders: Seq[Expression => InternalRow => Any],
evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any) {
def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any] = {
argExtractorBuilders
.zipAll(expressions, null, null)
.flatMap {
case (null, _) => None
case (builder, expr) => Some(builder(expr))
}
.toArray
}
}
evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any)

object InferrableFunction {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@ abstract class ST_Predicate
if (rightArray == null) {
null
} else {
val leftGeometry = GeometrySerializer.deserialize(leftArray)
val rightGeometry = GeometrySerializer.deserialize(rightArray)
try {
val leftGeometry = GeometrySerializer.deserialize(leftArray)
val rightGeometry = GeometrySerializer.deserialize(rightArray)
evalGeom(leftGeometry, rightGeometry)
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(leftGeometry, rightGeometry),
e)
}
}
}
Expand Down
Loading
Loading