[SPARK-40777][SQL][PROTOBUF] Protobuf import support and move error-classes. #38344

Status: Closed

Changes from 11 commits (25 commits total)

Commits:
c6177a3
review changes nit
SandishKumarHN Oct 18, 2022
8fad018
adding spark-protobuf error class into spark error-class framework
SandishKumarHN Oct 21, 2022
dac2fa8
adding spark-protobuf error class into spark error-class framework
SandishKumarHN Oct 21, 2022
d1c9b1e
adding spark-protobuf error class into spark error-class framework, i…
SandishKumarHN Oct 21, 2022
03b918f
Merge remote-tracking branch 'remote-spark/master' into SPARK-40777
SandishKumarHN Oct 21, 2022
60c2122
adding spark-protobuf error class into spark error-class framework, i…
SandishKumarHN Oct 21, 2022
7c866de
Merge branch 'SPARK-40777' into SPARK-40777-ProtoErrorCls
SandishKumarHN Oct 22, 2022
e6f3cab
spark-protobuf error clss frameworks, import support, timestamp and d…
SandishKumarHN Oct 22, 2022
70a5983
fixing typos
SandishKumarHN Oct 24, 2022
26e471b
review changes, name protobuf error class and comment in pom.xml
SandishKumarHN Oct 24, 2022
68f87c1
review chages rm includeMavenTypes and import own protos
SandishKumarHN Oct 24, 2022
9cdf4d5
review changes: nested import support, nit, build fix
SandishKumarHN Oct 25, 2022
4e1080e
review changes style changes, nit
SandishKumarHN Oct 25, 2022
dbaf24d
Merge remote-tracking branch 'remote-spark/master' into SPARK-40777-P…
SandishKumarHN Oct 26, 2022
6045ffe
review changes: search messages on all imports
SandishKumarHN Oct 26, 2022
3fc59b5
review changes: use prettyName
SandishKumarHN Oct 27, 2022
e5482a5
Merge remote-tracking branch 'remote-spark/master' into SPARK-40777-P…
SandishKumarHN Oct 27, 2022
dd63be8
review changes: nit, option, runtime error
SandishKumarHN Oct 28, 2022
e5140b0
error class name changes, more details to error message
SandishKumarHN Oct 28, 2022
3037415
Merge remote-tracking branch 'remote-spark/master' into SPARK-40777-P…
SandishKumarHN Oct 29, 2022
be02c9e
NO_UDF_INTERFACE_ERROR to NO_UDF_INTERFACE
SandishKumarHN Oct 29, 2022
d8cce82
review changes scala style, find, parseFileDescriptorSet
SandishKumarHN Nov 1, 2022
ad1f7e1
Merge remote-tracking branch 'remote-spark/master' into SPARK-40777-P…
SandishKumarHN Nov 1, 2022
48bcb5c
review changes buildDescriptor suggested changes
SandishKumarHN Nov 1, 2022
87918a1
review changes, error classes
SandishKumarHN Nov 3, 2022
File: ProtobufDataToCatalyst.scala

@@ -21,11 +21,10 @@ import scala.util.control.NonFatal

import com.google.protobuf.DynamicMessage

import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.protobuf.utils.{ProtobufOptions, ProtobufUtils, SchemaConverters}
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType}

@@ -71,16 +70,11 @@ private[protobuf] case class ProtobufDataToCatalyst(
@transient private lazy val parseMode: ParseMode = {
val mode = protobufOptions.parseMode
if (mode != PermissiveMode && mode != FailFastMode) {
throw new AnalysisException(unacceptableModeMessage(mode.name))
throw QueryCompilationErrors.parseModeUnsupportedError("from_protobuf", mode)
Review comment (Member): Use prettyName instead of "from_protobuf", please.

}
mode
}

private def unacceptableModeMessage(name: String): String = {
s"from_protobuf() doesn't support the $name mode. " +
s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}."
}

@transient private lazy val nullResultRow: Any = dataType match {
case st: StructType =>
val resultRow = new SpecificInternalRow(st.map(_.dataType))
@@ -98,13 +92,9 @@ private[protobuf] case class ProtobufDataToCatalyst(
case PermissiveMode =>
nullResultRow
case FailFastMode =>
throw new SparkException(
"Malformed records are detected in record parsing. " +
s"Current parse Mode: ${FailFastMode.name}. To process malformed records as null " +
"result, try setting the option 'mode' as 'PERMISSIVE'.",
e)
throw QueryCompilationErrors.malformedRecordsDetectedInRecordParsingError(e)
case _ =>
throw new AnalysisException(unacceptableModeMessage(parseMode.name))
throw QueryCompilationErrors.parseModeUnsupportedError("from_protobuf", parseMode)
Review comment (Member): ditto: prettyName

Reply (Contributor Author): @MaxGekk fixed.

}
}

@@ -119,8 +109,8 @@ private[protobuf] case class ProtobufDataToCatalyst(
case Some(number) =>
// Unknown fields contain a field with same number as a known field. Must be due to
// mismatch of schema between writer and reader here.
throw new IllegalArgumentException(s"Type mismatch encountered for field:" +
s" ${messageDescriptor.getFields.get(number)}")
throw QueryCompilationErrors.protobufFieldTypeMismatchError(
messageDescriptor.getFields.get(number).toString)
case None =>
}

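Applying the reviewer's suggestion above, the parse-mode check would reference the expression's prettyName rather than a hard-coded function name. A minimal sketch of the block from the diff, with the string literal swapped out (prettyName is inherited from Spark's Expression, and the helper signature follows the call sites shown in the diff):

```scala
@transient private lazy val parseMode: ParseMode = {
  val mode = protobufOptions.parseMode
  if (mode != PermissiveMode && mode != FailFastMode) {
    // prettyName ("from_protobuf") comes from the expression itself, so the error
    // message stays correct if the function is ever renamed or aliased.
    throw QueryCompilationErrors.parseModeUnsupportedError(prettyName, mode)
  }
  mode
}
```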
File: ProtobufDeserializer.scala

@@ -22,14 +22,14 @@ import com.google.protobuf.{ByteString, DynamicMessage, Message}
import com.google.protobuf.Descriptors._
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.protobuf.utils.ProtobufUtils.ProtoMatchedField
import org.apache.spark.sql.protobuf.utils.ProtobufUtils.toFieldStr
import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -61,10 +61,10 @@ private[sql] class ProtobufDeserializer(
}
}
} catch {
case ise: IncompatibleSchemaException =>
throw new IncompatibleSchemaException(
s"Cannot convert Protobuf type ${rootDescriptor.getName} " +
s"to SQL type ${rootCatalystType.sql}.",
case ise: AnalysisException =>
throw QueryCompilationErrors.cannotConvertProtobufTypeToCatalystTypeError(
rootDescriptor.getName,
rootCatalystType,
ise)
}

@@ -152,11 +152,6 @@ private[sql] class ProtobufDeserializer(
catalystType: DataType,
protoPath: Seq[String],
catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
val errorPrefix = s"Cannot convert Protobuf ${toFieldStr(protoPath)} to " +
s"SQL ${toFieldStr(catalystPath)} because "
val incompatibleMsg = errorPrefix +
s"schema is incompatible (protoType = ${protoType} ${protoType.toProto.getLabel} " +
s"${protoType.getJavaType} ${protoType.getType}, sqlType = ${catalystType.sql})"

(protoType.getJavaType, catalystType) match {

@@ -175,8 +170,9 @@
case (INT, ShortType) =>
(updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])

case (BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
case (
BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
newArrayWriter(protoType, protoPath, catalystPath, dataType, containsNull)

case (LONG, LongType) =>
@@ -199,7 +195,7 @@
(updater, ordinal, value) =>
val byte_array = value match {
case s: ByteString => s.toByteArray
case _ => throw new Exception("Invalid ByteString format")
case _ => throw QueryCompilationErrors.invalidByteStringFormatError()
}
updater.set(ordinal, byte_array)

@@ -244,7 +240,13 @@
case (ENUM, StringType) =>
(updater, ordinal, value) => updater.set(ordinal, UTF8String.fromString(value.toString))

case _ => throw new IncompatibleSchemaException(incompatibleMsg)
case _ =>
throw QueryCompilationErrors.cannotConvertProtobufTypeToSqlTypeError(
toFieldStr(protoPath),
toFieldStr(catalystPath),
s"${protoType} ${protoType.toProto.getLabel} ${protoType.getJavaType}" +
s" ${protoType.getType}",
catalystType)
}
}

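For orientation, the deserialization paths above are reached through from_protobuf, with malformed-record behavior controlled by the "mode" option handled in ProtobufDataToCatalyst. A hedged usage sketch: the from_protobuf signature is assumed from the spark-protobuf functions API (exact overloads may differ by version), and "Person" plus the descriptor path are placeholders.

```scala
import scala.collection.JavaConverters._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// df is assumed to be an existing DataFrame with a binary column "value"
// holding serialized protobuf messages.
val options = Map("mode" -> "PERMISSIVE").asJava // malformed records become a null row
val parsed = df.select(
  from_protobuf(col("value"), "Person", "/path/to/descriptor.desc", options).as("person"))
// With "mode" -> "FAILFAST", malformed records raise the malformed-record error
// introduced in this diff instead of producing null.
```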
File: ProtobufSerializer.scala

@@ -23,13 +23,14 @@ import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.protobuf.utils.ProtobufUtils.{toFieldStr, ProtoMatchedField}
import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.types._

/**
@@ -53,10 +54,10 @@ private[sql] class ProtobufSerializer(
newStructConverter(st, rootDescriptor, Nil, Nil).asInstanceOf[Any => Any]
}
} catch {
case ise: IncompatibleSchemaException =>
throw new IncompatibleSchemaException(
s"Cannot convert SQL type ${rootCatalystType.sql} to Protobuf type " +
s"${rootDescriptor.getName}.",
case ise: AnalysisException =>
throw QueryCompilationErrors.cannotConvertSqlTypeToProtobufError(
rootDescriptor.getName,
rootCatalystType,
ise)
}
if (nullable) { (data: Any) =>
@@ -77,8 +78,6 @@ private[sql] class ProtobufSerializer(
fieldDescriptor: FieldDescriptor,
catalystPath: Seq[String],
protoPath: Seq[String]): Converter = {
val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
s"to Protobuf ${toFieldStr(protoPath)} because "
(catalystType, fieldDescriptor.getJavaType) match {
case (NullType, _) =>
(getter, ordinal) => null
@@ -104,10 +103,11 @@ private[sql] class ProtobufSerializer(
(getter, ordinal) =>
val data = getter.getUTF8String(ordinal).toString
if (!enumSymbols.contains(data)) {
throw new IncompatibleSchemaException(
errorPrefix +
s""""$data" cannot be written since it's not defined in enum """ +
enumSymbols.mkString("\"", "\", \"", "\""))
throw QueryCompilationErrors.cannotConvertCatalystTypeToProtobufEnumTypeError(
toFieldStr(catalystPath),
toFieldStr(protoPath),
data,
enumSymbols.mkString("\"", "\", \"", "\""))
}
fieldDescriptor.getEnumType.findValueByName(data)
case (StringType, STRING) =>
@@ -124,7 +124,8 @@ private[sql] class ProtobufSerializer(
case (TimestampType, MESSAGE) =>
(getter, ordinal) =>
val millis = DateTimeUtils.microsToMillis(getter.getLong(ordinal))
Timestamp.newBuilder()
Timestamp
.newBuilder()
.setSeconds((millis / 1000))
.setNanos(((millis % 1000) * 1000000).toInt)
.build()
@@ -201,7 +202,8 @@ private[sql] class ProtobufSerializer(
val calendarInterval = IntervalUtils.fromIntervalString(dayTimeIntervalString)

val millis = DateTimeUtils.microsToMillis(calendarInterval.microseconds)
val duration = Duration.newBuilder()
val duration = Duration
.newBuilder()
.setSeconds((millis / 1000))
.setNanos(((millis % 1000) * 1000000).toInt)

@@ -215,10 +217,12 @@ private[sql] class ProtobufSerializer(
duration.build()

case _ =>
throw new IncompatibleSchemaException(
errorPrefix +
s"schema is incompatible (sqlType = ${catalystType.sql}, " +
s"protoType = ${fieldDescriptor.getJavaType})")
throw QueryCompilationErrors.cannotConvertCatalystTypeToProtobufTypeError(
toFieldStr(catalystPath),
Review comment (Member): Please don't quote it twice if you do that inside of cannotConvertCatalystTypeToProtobufTypeError() by toSQLId.

Reply (Contributor Author): Okay, will call toSQLType inside the error class and avoid using toFieldStr for sqlType.

toFieldStr(protoPath),
catalystType,
s"${fieldDescriptor} ${fieldDescriptor.toProto.getLabel} ${fieldDescriptor.getJavaType}" +
s" ${fieldDescriptor.getType}")
}
}

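Following the exchange above about double quoting, the error helper itself would become the single place where the SQL type gets quoted, e.g. via toSQLType from QueryErrorsBase. The sketch below is illustrative only: the error class name, parameter keys, and the AnalysisException error-class constructor are assumptions, not the exact entries this PR adds.

```scala
// Inside QueryCompilationErrors (which mixes in QueryErrorsBase, providing toSQLType):
def cannotConvertCatalystTypeToProtobufTypeError(
    sqlColumn: String,
    protobufColumn: String,
    sqlType: DataType,
    protobufType: String): Throwable = {
  new AnalysisException(
    errorClass = "CANNOT_CONVERT_SQL_TYPE_TO_PROTOBUF_FIELD_TYPE", // illustrative name
    messageParameters = Map(
      "sqlColumn" -> sqlColumn,
      "protobufColumn" -> protobufColumn,
      "sqlType" -> toSQLType(sqlType), // quoted exactly once, here
      "protobufType" -> protobufType))
}
```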
File: ProtobufUtils.scala

@@ -26,8 +26,8 @@ import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBuffer
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -61,9 +61,9 @@ private[sql] object ProtobufUtils extends Logging {
protoPath: Seq[String],
catalystPath: Seq[String]) {
if (descriptor.getName == null) {
throw new IncompatibleSchemaException(
s"Attempting to treat ${descriptor.getName} as a RECORD, " +
s"but it was: ${descriptor.getContainingType}")
throw QueryCompilationErrors.unknownProtobufMessageTypeError(
descriptor.getName,
descriptor.getContainingType().getName)
}

private[this] val protoFieldArray = descriptor.getFields.asScala.toArray
@@ -79,30 +79,29 @@

/**
* Validate that there are no Catalyst fields which don't have a matching Protobuf field,
* throwing [[IncompatibleSchemaException]] if such extra fields are found. If
* `ignoreNullable` is false, consider nullable Catalyst fields to be eligible to be an extra
* field; otherwise, ignore nullable Catalyst fields when checking for extras.
* throwing [[AnalysisException]] if such extra fields are found. If `ignoreNullable` is
* false, consider nullable Catalyst fields to be eligible to be an extra field; otherwise,
* ignore nullable Catalyst fields when checking for extras.
*/
def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit =
catalystSchema.fields.foreach { sqlField =>
if (getFieldByName(sqlField.name).isEmpty &&
(!ignoreNullable || !sqlField.nullable)) {
throw new IncompatibleSchemaException(
s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Protobuf schema")
throw QueryCompilationErrors.cannotFindCatalystTypeInProtobufSchemaError(
toFieldStr(catalystPath :+ sqlField.name))
}
}

/**
* Validate that there are no Protobuf fields which don't have a matching Catalyst field,
* throwing [[IncompatibleSchemaException]] if such extra fields are found. Only required
* (non-nullable) fields are checked; nullable fields are ignored.
* throwing [[AnalysisException]] if such extra fields are found. Only required (non-nullable)
* fields are checked; nullable fields are ignored.
*/
def validateNoExtraRequiredProtoFields(): Unit = {
val extraFields = protoFieldArray.toSet -- matchedFields.map(_.fieldDescriptor)
extraFields.filterNot(isNullable).foreach { extraField =>
throw new IncompatibleSchemaException(
s"Found ${toFieldStr(protoPath :+ extraField.getName())} in Protobuf schema " +
"but there is no match in the SQL schema")
throw QueryCompilationErrors.cannotFindProtobufFieldInCatalystError(
toFieldStr(protoPath :+ extraField.getName()))
}
}

@@ -125,10 +124,11 @@ private[sql] object ProtobufUtils extends Logging {
case Seq(protoField) => Some(protoField)
case Seq() => None
case matches =>
throw new IncompatibleSchemaException(
s"Searching for '$name' in " +
s"Protobuf schema at ${toFieldStr(protoPath)} gave ${matches.size} matches. " +
s"Candidates: " + matches.map(_.getName()).mkString("[", ", ", "]"))
throw QueryCompilationErrors.protobufFieldMatchError(
name,
toFieldStr(protoPath),
s"${matches.size}",
matches.map(_.getName()).mkString("[", ", ", "]"))
}
}
}
@@ -159,14 +159,12 @@ private[sql] object ProtobufUtils extends Logging {
} catch {
case _: ClassNotFoundException =>
val hasDots = protobufClassName.contains(".")
throw new IllegalArgumentException(
s"Could not load Protobuf class with name '$protobufClassName'" +
(if (hasDots) "" else ". Ensure the class name includes package prefix.")
)
throw QueryCompilationErrors.protobufClassLoadError(protobufClassName,
if (hasDots) "" else ". Ensure the class name includes package prefix.")
}

if (!classOf[Message].isAssignableFrom(protobufClass)) {
throw new IllegalArgumentException(s"$protobufClassName is not a Protobuf message type")
throw QueryCompilationErrors.protobufMessageTypeError(protobufClassName)
// TODO: Need to support V2. This might work with V2 classes too.
}

@@ -185,7 +183,7 @@ private[sql] object ProtobufUtils extends Logging {
descriptor match {
case Some(d) => d
case None =>
throw new RuntimeException(s"Unable to locate Message '$messageName' in Descriptor")
throw QueryCompilationErrors.unableToLocateProtobufMessageError(messageName)
}
}

@@ -196,27 +194,31 @@
fileDescriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(dscFile)
} catch {
case ex: InvalidProtocolBufferException =>
// TODO move all the exceptions to core/src/main/resources/error/error-classes.json
throw new RuntimeException("Error parsing descriptor byte[] into Descriptor object", ex)
throw QueryCompilationErrors.descrioptorParseError(ex)
case ex: IOException =>
throw new RuntimeException(
"Error reading Protobuf descriptor file at path: " +
descFilePath,
ex)
throw QueryCompilationErrors.cannotFindDescriptorFileError(descFilePath, ex)
}

val descriptorProto: DescriptorProtos.FileDescriptorProto = fileDescriptorSet.getFile(0)
val descriptorProto: DescriptorProtos.FileDescriptorProto =
Review comment (Contributor Author): @rangadi This handles the protobuf import; please let me know if you know of a better approach.

Review comment (Contributor): Could you add some brief comments here? Is the last file the right file?

fileDescriptorSet.getFileList.asScala.last

var fileDescriptorList = List[Descriptors.FileDescriptor]()
Review comment (Contributor): Is this the import file list? What happens when the imported file imports other files? i.e. A imports B and B imports C.

Reply (Contributor Author): @rangadi some changes were made to support nested imports. Reading protobuf descriptors from the bottom up, the last element in FileDescriptorSet is the initial FileDescriptorProto, from which we will continue to find more FileDescriptors recursively.

for (fd <- fileDescriptorSet.getFileList.asScala) {
if (descriptorProto.getName != fd.getName) {
fileDescriptorList = fileDescriptorList ++
List(Descriptors.FileDescriptor.buildFrom(fd, new Array[Descriptors.FileDescriptor](0)))
}
}
try {
val fileDescriptor: Descriptors.FileDescriptor = Descriptors.FileDescriptor.buildFrom(
descriptorProto,
new Array[Descriptors.FileDescriptor](0))
val fileDescriptor: Descriptors.FileDescriptor =
Descriptors.FileDescriptor.buildFrom(descriptorProto, fileDescriptorList.toArray)
if (fileDescriptor.getMessageTypes().isEmpty()) {
throw new RuntimeException("No MessageTypes returned, " + fileDescriptor.getName());
throw QueryCompilationErrors.noProtobufMessageTypeReturnError(fileDescriptor.getName())
}
fileDescriptor
} catch {
case e: Descriptors.DescriptorValidationException =>
throw new RuntimeException("Error constructing FileDescriptor", e)
throw QueryCompilationErrors.failedParsingDescriptorError(e)
}
}

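The review thread above describes walking the FileDescriptorSet from the last entry (the root file) backwards to resolve imports. One way to handle transitive imports (A imports B, B imports C), sketched under the assumption that every imported file is present in the descriptor set, is to resolve each dependency by name recursively before building its FileDescriptor; cycle and duplicate handling are omitted for brevity.

```scala
import scala.collection.JavaConverters._
import com.google.protobuf.DescriptorProtos
import com.google.protobuf.Descriptors

def buildFileDescriptor(
    proto: DescriptorProtos.FileDescriptorProto,
    filesByName: Map[String, DescriptorProtos.FileDescriptorProto])
  : Descriptors.FileDescriptor = {
  // Build every import first, so nested imports are resolved bottom-up.
  val dependencies = proto.getDependencyList.asScala.flatMap { depName =>
    filesByName.get(depName).map(dep => buildFileDescriptor(dep, filesByName))
  }
  Descriptors.FileDescriptor.buildFrom(proto, dependencies.toArray)
}

// Usage against the parsed set, mirroring the "last file is the root" convention above:
// val root = fileDescriptorSet.getFileList.asScala.last
// val index = fileDescriptorSet.getFileList.asScala.map(f => f.getName -> f).toMap
// buildFileDescriptor(root, index)
```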