Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,12 @@ object DeserializerBuildHelper {
val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
// Assumes we are deserializing the first column of a row.
val input = GetColumnByOrdinal(0, enc.dataType)
enc match {
case AgnosticEncoders.RowEncoder(fields) =>
val children = fields.zipWithIndex.map { case (f, i) =>
createDeserializer(f.enc, GetStructField(input, i), walkedTypePath)
}
CreateExternalRow(children, enc.schema)
case _ =>
val deserializer = createDeserializer(
enc,
upCastToExpectedType(input, enc.dataType, walkedTypePath),
walkedTypePath)
expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
}
val deserializer = createDeserializer(
enc,
upCastToExpectedType(input, enc.dataType, walkedTypePath),
walkedTypePath,
isTopLevel = true)
expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
}

/**
Expand All @@ -265,11 +258,13 @@ object DeserializerBuildHelper {
* external representation.
* @param path The expression which can be used to extract serialized value.
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
* @param isTopLevel true if we are creating a deserializer for the top level value.
*/
private def createDeserializer(
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath): Expression = enc match {
walkedTypePath: WalkedTypePath,
isTopLevel: Boolean = false): Expression = enc match {
case ae: AgnosticExpressionPathEncoder[_] =>
ae.fromCatalyst(path)
case _ if isNativeEncoder(enc) =>
Expand Down Expand Up @@ -408,13 +403,12 @@ object DeserializerBuildHelper {
NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter))

case AgnosticEncoders.RowEncoder(fields) =>
val isExternalRow = !path.dataType.isInstanceOf[StructType]
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val newTypePath = walkedTypePath.recordField(
f.enc.clsTag.runtimeClass.getName,
f.name)
val deserializer = createDeserializer(f.enc, GetStructField(path, i), newTypePath)
if (isExternalRow) {
if (!isTopLevel) {
exprs.If(
Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
exprs.Literal.create(null, externalDataTypeFor(f.enc)),
Expand Down Expand Up @@ -460,7 +454,7 @@ object DeserializerBuildHelper {
Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
"decode",
dataTypeForClass(tag.runtimeClass),
createDeserializer(encoder, path, walkedTypePath) :: Nil)
createDeserializer(encoder, path, walkedTypePath, isTopLevel) :: Nil)
}

private def deserializeArray(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,22 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, "x")))
}

test("SPARK-52614: transforming encoder row encoder in product encoder") {
val schema = new StructType().add("a", LongType).add("b", StringType)
val wrapperEncoder = TransformingEncoder(
classTag[Wrapper[Row]],
RowEncoder.encoderFor(schema),
new WrapperCodecProvider[Row])
val encoder = ExpressionEncoder(ProductEncoder(
classTag[V[Wrapper[Row]]],
Seq(EncoderField("v", wrapperEncoder, nullable = false, Metadata.empty)),
None))
.resolveAndBind()
val toRow = encoder.createSerializer()
val fromRow = encoder.createDeserializer()
assert(fromRow(toRow(V(new Wrapper(Row(9L, "x"))))) == V(new Wrapper(Row(9L, "x"))))
}

// below tests are related to SPARK-49960 and TransformingEncoder usage
test("""Encoder with OptionEncoder of transformation""".stripMargin) {
type T = Option[V[V[Int]]]
Expand Down