diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index 4bd59ddce6cee..52870be5fbe07 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -208,7 +208,7 @@ private[sql] object ProtobufUtils extends Logging {
       ).toList
       fileDescriptorList
     } catch {
-      case e: Descriptors.DescriptorValidationException =>
+      case e: Exception =>
         throw QueryCompilationErrors.failedParsingDescriptorError(descFilePath, e)
     }
   }
diff --git a/connector/protobuf/src/test/resources/protobuf/basicmessage.proto b/connector/protobuf/src/test/resources/protobuf/basicmessage.proto
index 4252f349cf045..8f4c1bb8eae42 100644
--- a/connector/protobuf/src/test/resources/protobuf/basicmessage.proto
+++ b/connector/protobuf/src/test/resources/protobuf/basicmessage.proto
@@ -17,6 +17,7 @@
 // cd connector/protobuf/src/test/resources/protobuf
 // protoc --java_out=./ basicmessage.proto
 // protoc --include_imports --descriptor_set_out=basicmessage.desc --java_out=org/apache/spark/sql/protobuf/ basicmessage.proto
+// protoc --descriptor_set_out=basicmessage_noimports.desc --java_out=org/apache/spark/sql/protobuf/ basicmessage.proto
 
 syntax = "proto3";
 
diff --git a/connector/protobuf/src/test/resources/protobuf/basicmessage_noimports.desc b/connector/protobuf/src/test/resources/protobuf/basicmessage_noimports.desc
new file mode 100644
index 0000000000000..26ba6552cb01d
Binary files /dev/null and b/connector/protobuf/src/test/resources/protobuf/basicmessage_noimports.desc differ
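The broadened `case e: Exception` above exists because a descriptor set generated without `--include_imports` (like the new `basicmessage_noimports.desc`) parses fine but cannot be built into usable descriptors. Below is a minimal sketch of that failure mode against the protobuf-java API directly; `descPath` is a stand-in for the test file path, and this is illustrative rather than Spark's actual resolution code:

```scala
import java.io.FileInputStream

import com.google.protobuf.DescriptorProtos.FileDescriptorSet
import com.google.protobuf.Descriptors

val descPath = "basicmessage_noimports.desc" // stand-in path (assumption)

// Parsing the serialized FileDescriptorSet succeeds even when imports were
// not bundled; the set simply lacks the imported nestedenum.proto entry.
val fileDescriptorSet = FileDescriptorSet.parseFrom(new FileInputStream(descPath))

// Building the FileDescriptor is where the missing import bites: buildFrom
// requires every declared dependency to be supplied. With none available it
// throws (DescriptorValidationException here, though the surrounding
// resolution logic can surface other exception types, hence the broadened
// `case e: Exception` in the patch).
val fd = Descriptors.FileDescriptor.buildFrom(
  fileDescriptorSet.getFile(0),
  Array.empty[Descriptors.FileDescriptor])
```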
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index 271c5b0fec894..9f9b51006ca81 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -123,16 +123,21 @@ class ProtobufCatalystDataConversionSuite
     StringType -> ("StringMsg", ""))
 
   testingTypes.foreach { dt =>
-    val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1)
+    val seed = scala.util.Random.nextInt(RandomDataGenerator.MAX_STR_LEN)
     test(s"single $dt with seed $seed") {
 
       val (messageName, defaultValue) = catalystTypesToProtoMessages(dt.fields(0).dataType)
 
       val rand = new scala.util.Random(seed)
       val generator = RandomDataGenerator.forType(dt, rand = rand).get
-      var data = generator()
-      while (data.asInstanceOf[Row].get(0) == defaultValue) // Do not use default values, since
-        data = generator() // from_protobuf() returns null in v3.
+      var data = generator().asInstanceOf[Row]
+      // Do not use default values, since from_protobuf() returns null in v3.
+      while (
+        data != null &&
+          (data.get(0) == defaultValue ||
+            (dt == BinaryType &&
+              data.get(0).asInstanceOf[Array[Byte]].isEmpty)))
+        data = generator().asInstanceOf[Row]
 
       val converter = CatalystTypeConverters.createToCatalystConverter(dt)
       val input = Literal.create(converter(data), dt)
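The regeneration loop above skips rows that hold a field's proto3 default (and, for `BinaryType`, empty byte arrays) because proto3 never writes default values to the wire: after a round trip they are indistinguishable from unset fields, and `from_protobuf()` surfaces them as null. A quick illustration using the standard `Int32Value` wrapper from protobuf-java (chosen only to demonstrate the wire behavior; the suite itself uses its own generated test messages):

```scala
import com.google.protobuf.Int32Value

// proto3 omits default-valued fields from the wire format entirely, so a
// message whose only field holds its default encodes to zero bytes.
val defaultMsg = Int32Value.newBuilder().setValue(0).build()
assert(defaultMsg.toByteArray.isEmpty)

// A non-default value is actually encoded and survives the round trip.
val nonDefaultMsg = Int32Value.newBuilder().setValue(42).build()
assert(nonDefaultMsg.toByteArray.nonEmpty)
```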
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 199ef235f1496..00ec56f90a632 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -677,4 +677,18 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
         === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0))
     }
   }
+
+  test("raise cannot construct protobuf descriptor error") {
+    val df = Seq(ByteString.empty().toByteArray).toDF("value")
+    val testFileDescriptor = testFile("basicmessage_noimports.desc").replace("file:/", "/")
+
+    val e = intercept[AnalysisException] {
+      df.select(functions.from_protobuf($"value", "BasicMessage", testFileDescriptor) as 'sample)
+        .where("sample.string_value == \"slam\"").show()
+    }
+    checkError(
+      exception = e,
+      errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR",
+      parameters = Map("descFilePath" -> testFileDescriptor))
+  }
 }
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 840535654ed6a..22b9d58bbd449 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -177,6 +177,29 @@ class ProtobufSerdeSuite extends SharedSparkSession {
     withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
   }
 
+  test("raise cannot parse and construct protobuf descriptor error") {
+    // passing serde_suite.proto instead of serde_suite.desc
+    var testFileDesc = testFile("serde_suite.proto").replace("file:/", "/")
+    val e1 = intercept[AnalysisException] {
+      ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
+    }
+
+    checkError(
+      exception = e1,
+      errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR",
+      parameters = Map("descFilePath" -> testFileDesc))
+
+    testFileDesc = testFile("basicmessage_noimports.desc").replace("file:/", "/")
+    val e2 = intercept[AnalysisException] {
+      ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
+    }
+
+    checkError(
+      exception = e2,
+      errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR",
+      parameters = Map("descFilePath" -> testFileDesc))
+  }
+
   /**
    * Attempt to convert `catalystSchema` to `protoSchema` (or vice-versa if `deserialize` is
    * true), assert that it fails, and assert that the _cause_ of the thrown exception has a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index b56e1957f7720..139ea236e495c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3325,7 +3325,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
   def descrioptorParseError(descFilePath: String, cause: Throwable): Throwable = {
     new AnalysisException(
       errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR",
-      messageParameters = Map.empty("descFilePath" -> descFilePath),
+      messageParameters = Map("descFilePath" -> descFilePath),
       cause = Option(cause.getCause))
   }
 
@@ -3339,7 +3339,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
   def failedParsingDescriptorError(descFilePath: String, cause: Throwable): Throwable = {
     new AnalysisException(
       errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR",
-      messageParameters = Map.empty("descFilePath" -> descFilePath),
+      messageParameters = Map("descFilePath" -> descFilePath),
       cause = Option(cause.getCause))
   }
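The `Map.empty` fix in the last two hunks is subtle: `Map.empty("descFilePath" -> descFilePath)` compiles because it is sugar for `Map.empty.apply(key)`, a lookup of the tuple in an empty map whose result type `Nothing` conforms to the expected `Map[String, String]`. At runtime it throws `NoSuchElementException`, replacing the intended error message with an unrelated one. A standalone sketch of both forms; the path value is hypothetical:

```scala
import scala.util.Try

val path = "/tmp/example.desc" // hypothetical path, for illustration only

// Buggy form: Map.empty(k -> v) is Map.empty.apply((k, v)), a key lookup
// in an empty map. It type-checks but fails with NoSuchElementException.
assert(Try(Map.empty("descFilePath" -> path)).isFailure)

// Fixed form: the Map companion's apply builds the intended one-entry map.
val params = Map("descFilePath" -> path)
assert(params("descFilePath") == path)
```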