56 changes: 28 additions & 28 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1537,6 +1537,7 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
PYTHON_UDF_FIELD_NUMBER: builtins.int
SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
JAVA_UDF_FIELD_NUMBER: builtins.int
IS_DISTINCT_FIELD_NUMBER: builtins.int
function_name: builtins.str
"""(Required) Name of the user-defined function."""
deterministic: builtins.bool
@@ -1552,6 +1553,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
def scalar_scala_udf(self) -> global___ScalarScalaUDF: ...
@property
def java_udf(self) -> global___JavaUDF: ...
is_distinct: builtins.bool
"""(Required) Indicate if this function should be applied on distinct values."""
def __init__(
self,
*,
@@ -1561,6 +1564,7 @@
python_udf: global___PythonUDF | None = ...,
scalar_scala_udf: global___ScalarScalaUDF | None = ...,
java_udf: global___JavaUDF | None = ...,
is_distinct: builtins.bool = ...,
) -> None: ...
def HasField(
self,
@@ -1586,6 +1590,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
b"function",
"function_name",
b"function_name",
"is_distinct",
b"is_distinct",
"java_udf",
b"java_udf",
"python_udf",
@@ -406,14 +406,20 @@ object DeserializerBuildHelper {
NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter))

case AgnosticEncoders.RowEncoder(fields) =>
val isExternalRow = !path.dataType.isInstanceOf[StructType]
Contributor:
Is it really safe to call dataType here? The path expression might not be resolved and then this will throw an exception.

Contributor Author:
It should be. If you don't know the dataType at this point, then you can't build a deserializer.

Contributor:
The problem comes up if you have a RowEncoder being used inside a ProductEncoder. Then the path in the recursion will come from

createDeserializer(field.enc, getter, newTypePath),

and then
addToPath(path, field.name, field.enc.dataType, newTypePath)
and then here
val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
so the path will contain UnresolvedExtractValue and the .dataType will throw

   org.apache.spark.sql.catalyst.analysis.UnresolvedException: [INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
  at org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue.dataType(unresolved.scala:939)
  at org.apache.spark.sql.catalyst.DeserializerBuildHelper$.createDeserializer(DeserializerBuildHelper.scala:411)

Is there some assumption somewhere that the encoders should not be fully composable and RowEncoder can only be used in certain cases?

Contributor:
@hvanhovell Created this PR #51319 that fixes the issue.
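
To make the failure mode in this thread concrete, here is a minimal sketch of a RowEncoder nested inside a ProductEncoder. It is not taken from this PR, and the exact ProductEncoder/EncoderField constructor shapes are assumptions; the point is that the nested Row is reached through a not-yet-resolved path, which is where an eager dataType call can throw before the follow-up fix.

// Illustrative sketch only; constructor shapes below are assumptions, not code from this PR.
import scala.reflect.classTag
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{EncoderField, ProductEncoder, StringEncoder}
import org.apache.spark.sql.types.{LongType, Metadata, StructType}

case class Outer(name: String, payload: Row)

val payloadSchema = new StructType().add("a", LongType)
val nestedRowEncoder = ProductEncoder(
  classTag[Outer],
  Seq(
    EncoderField("name", StringEncoder, nullable = true, metadata = Metadata.empty),
    EncoderField("payload", RowEncoder.encoderFor(payloadSchema), nullable = true, metadata = Metadata.empty)),
  None)

// ExpressionEncoder builds the deserializer eagerly; the nested Row field is
// reached through an unresolved extraction, which is where a premature
// dataType call can fail as described in the thread above.
val encoder = ExpressionEncoder(nestedRowEncoder)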

val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val newTypePath = walkedTypePath.recordField(
f.enc.clsTag.runtimeClass.getName,
f.name)
exprs.If(
Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
exprs.Literal.create(null, externalDataTypeFor(f.enc)),
createDeserializer(f.enc, GetStructField(path, i), newTypePath))
val deserializer = createDeserializer(f.enc, GetStructField(path, i), newTypePath)
if (isExternalRow) {
exprs.If(
Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
exprs.Literal.create(null, externalDataTypeFor(f.enc)),
deserializer)
} else {
deserializer
}
}
exprs.If(IsNull(path),
exprs.Literal.create(null, externalDataTypeFor(enc)),
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, Java
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, NewInstance}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts}
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -200,8 +200,7 @@ case class ExpressionEncoder[T](
UnresolvedAttribute.quoted(part.toString)
case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) =>
GetColumnByOrdinal(ordinal, dt)
case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n
case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i
case If(IsNull(GetColumnByOrdinal(0, _)), _, e) => e
}
} else {
// For other input objects like primitive, array, map, etc., we deserialize the first column
@@ -142,6 +142,23 @@ case class OptionNestedGeneric[T](list: Option[T])
case class MapNestedGenericKey[T](list: Map[T, Int])
case class MapNestedGenericValue[T](list: Map[Int, T])

class Wrapper[T](val value: T) {
override def hashCode(): Int = value.hashCode()
override def equals(obj: Any): Boolean = obj match {
case other: Wrapper[T @unchecked] => value == other.value
case _ => false
}
}

class WrapperCodec[T] extends Codec[Wrapper[T], T] {
override def encode(in: Wrapper[T]): T = in.value
override def decode(out: T): Wrapper[T] = new Wrapper(out)
}

class WrapperCodecProvider[T] extends (() => Codec[Wrapper[T], T]) {
override def apply(): Codec[Wrapper[T], T] = new WrapperCodec[T]
}

class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest
with QueryErrorsBase {
OuterScopes.addOuterScope(this)
@@ -568,6 +585,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
encodeDecodeTest(FooEnum.E1, "scala Enum")



private def testTransformingEncoder(
name: String,
provider: () => Codec[Any, Array[Byte]]): Unit = test(name) {
@@ -585,6 +603,19 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
testTransformingEncoder("transforming java serialization encoder", JavaSerializationCodec)
testTransformingEncoder("transforming kryo encoder", KryoSerializationCodec)

test("transforming row encoder") {
val schema = new StructType().add("a", LongType).add("b", StringType)
val encoder = ExpressionEncoder(TransformingEncoder(
classTag[Wrapper[Row]],
RowEncoder.encoderFor(schema),
new WrapperCodecProvider[Row]))
.resolveAndBind()
val toRow = encoder.createSerializer()
val fromRow = encoder.createDeserializer()
assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, "x")))
}


// Scala / Java big decimals ----------------------------------------------------------

encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
@@ -27,9 +27,9 @@ import org.apache.spark.sql.{AnalysisException, Encoder, Encoders, Row}
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder, StringEncoder}
import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{call_function, col, struct, udaf, udf}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.{call_function, col, count, lit, struct, udaf, udf}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}

/**
* All tests in this class requires client UDF defined in this test class synced with the server.
@@ -442,6 +442,41 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
.as(StringEncoder)
checkDataset(ds, "01", "12")
}

test("inline UserDefinedAggregateFunction") {
val summer0 = new LongSummer(0)
val summer1 = new LongSummer(1)
val summer3 = new LongSummer(3)
val ds = spark
.range(10)
.select(
count(lit(1)),
summer0(),
summer1(col("id")),
summer3(col("id"), col("id") + 1, col("id") + 2))
checkDataset(ds, Row(10L, 10L, Row(45L, 10L), Row(45L, 55L, 65L, 10L)))
}

test("inline UserDefinedAggregateFunction distinct") {
val summer1 = new LongSummer(1)
val ds =
spark.range(10).union(spark.range(10)).select(count(lit(1)), summer1.distinct(col("id")))
checkDataset(ds, Row(20L, Row(45L, 10L)))
}

test("register UserDefinedAggregateFunction") {
spark.udf.register("s0", new LongSummer(0))
spark.udf.register("s2", new LongSummer(2))
spark.udf.register("s4", new LongSummer(4))
val ds = spark
.range(10)
.select(
count(lit(1)),
call_function("s0"),
call_function("s2", col("id"), col("id") + 1),
call_function("s4", col("id"), col("id") + 1, col("id") + 2, col("id") + 3))
checkDataset(ds, Row(10L, 10L, Row(45L, 55L, 10L), Row(45L, 55L, 65L, 75L, 10L)))
}
}

case class UdafTestInput(id: Long, extra: Long)
@@ -503,3 +538,56 @@ object RowAggregator extends Aggregator[Row, (Long, Long), Long] {
class StringConcat extends UDF2[String, String, String] {
override def call(t1: String, t2: String): String = t1 + t2
}

class LongSummer(size: Int) extends UserDefinedAggregateFunction {
assert(size >= 0)

override def inputSchema: StructType = {
StructType(Array.tabulate(size)(i => StructField(s"val_$i", LongType)))
}

override def bufferSchema: StructType = inputSchema.add("counter", LongType)

override def dataType: DataType = {
if (size == 0) {
LongType
} else {
bufferSchema
}
}

override def deterministic: Boolean = true

override def initialize(buffer: MutableAggregationBuffer): Unit = {
var i = 0
while (i < size + 1) {
buffer.update(i, 0L)
i += 1
}
}

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
var i = 0
while (i < size) {
buffer.update(i, buffer.getLong(i) + input.getLong(i))
i += 1
}
buffer.update(size, buffer.getLong(size) + 1)
}

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
var i = 0
while (i < size + 1) {
buffer1.update(i, buffer1.getLong(i) + buffer2.getLong(i))
i += 1
}
}

override def evaluate(buffer: Row): Any = {
if (size == 0) {
buffer.getLong(0)
} else {
buffer
}
}
}
@@ -384,6 +384,8 @@ message CommonInlineUserDefinedFunction {
ScalarScalaUDF scalar_scala_udf = 5;
JavaUDF java_udf = 6;
}
// (Required) Indicate if this function should be applied on distinct values.
bool is_distinct = 7;
}

message PythonUDF {
@@ -19,8 +19,9 @@ package org.apache.spark.sql.connect

import org.apache.spark.connect.proto
import org.apache.spark.sql
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.types.DataType

/**
@@ -55,6 +56,13 @@ class UDFRegistration(session: SparkSession) extends sql.UDFRegistration {
/** @inheritdoc */
override def register(
name: String,
udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction =
throw ConnectClientUnsupportedErrors.registerUdaf()
udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
val wrapped = UserDefinedAggregator(
aggregator = new UserDefinedAggregateFunctionWrapper(udaf),
inputEncoder = RowEncoder.encoderFor(udaf.inputSchema),
givenName = Option(name),
deterministic = udaf.deterministic)
register(name, wrapped, "scala_udf", validateParameterCount = false)
udaf
}
}
@@ -69,10 +69,12 @@ private[sql] object UdfToProtoUtils {
*/
private[sql] def toProto(
udf: UserDefinedFunction,
arguments: Seq[proto.Expression] = Nil): proto.CommonInlineUserDefinedFunction = {
arguments: Seq[proto.Expression] = Nil,
isDistinct: Boolean = false): proto.CommonInlineUserDefinedFunction = {
val invokeUdf = proto.CommonInlineUserDefinedFunction
.newBuilder()
.setDeterministic(udf.deterministic)
.setIsDistinct(isDistinct)
.addAllArguments(arguments.asJava)
val protoUdf = invokeUdf.getScalarScalaUdfBuilder
.setNullable(udf.nullable)
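
For reference, a hedged sketch of what the new flag looks like at the protocol level. setIsDistinct and setDeterministic are the builder calls used in toProto above; the function name here is made up for illustration.

// Hypothetical wire-level sketch: building the proto message with the new
// flag, using the same builder calls as UdfToProtoUtils.toProto above.
import org.apache.spark.connect.proto

val distinctCall = proto.CommonInlineUserDefinedFunction
  .newBuilder()
  .setFunctionName("long_summer") // hypothetical name
  .setDeterministic(true)
  .setIsDistinct(true)            // new field (is_distinct = 7) from this change
  .build()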
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect

import scala.reflect.classTag

import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.encoders.{Codec, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.TransformingEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.StructType

/**
* An [[Aggregator]] that wraps a [[UserDefinedAggregateFunction]]. This allows us to execute and
* register these (deprecated) UDAFs using the Aggregator code path.
*
* This implementation assumes that the aggregation buffers can be updated in place. See
* `org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate` for more
* information.
*/
private[connect] class UserDefinedAggregateFunctionWrapper(udaf: UserDefinedAggregateFunction)
extends Aggregator[Row, MutableRow, Any] {
override def zero: MutableRow = {
val row = new MutableRow(udaf.bufferSchema)
udaf.initialize(row)
row
}

override def reduce(b: MutableRow, a: Row): MutableRow = {
udaf.update(b, a)
b
}

override def merge(b1: MutableRow, b2: MutableRow): MutableRow = {
udaf.merge(b1, b2)
b1
}

override def finish(reduction: MutableRow): Any = {
udaf.evaluate(reduction)
}

override def bufferEncoder: Encoder[MutableRow] = {
TransformingEncoder(
classTag[MutableRow],
RowEncoder.encoderFor(udaf.bufferSchema),
MutableRow)
}

override def outputEncoder: Encoder[Any] = {
RowEncoder
.encoderForDataType(udaf.dataType, lenient = false)
.asInstanceOf[Encoder[Any]]
}
}

/**
* Mutable row implementation that is used by [[UserDefinedAggregateFunctionWrapper]]. This code
* assumes that it is allowed to mutate/reuse buffers during aggregation.
*/
private[connect] class MutableRow(private[this] val values: Array[Any])
extends MutableAggregationBuffer {
def this(schema: StructType) = this(new Array[Any](schema.length))
def this(row: Row) = this(row.asInstanceOf[GenericRow].values)
override def length: Int = values.length
override def update(i: Int, value: Any): Unit = values(i) = value
override def get(i: Int): Any = values(i)
override def copy(): MutableRow = new MutableRow(values.clone())
def asGenericRow: Row = new GenericRow(values)
}

private[connect] object MutableRow extends (() => Codec[MutableRow, Row]) {
object MutableRowCodec extends Codec[MutableRow, Row] {
override def encode(in: MutableRow): Row = in.asGenericRow
override def decode(out: Row): MutableRow = new MutableRow(out)
}

override def apply(): Codec[MutableRow, Row] = MutableRowCodec
}
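
A hedged usage sketch tying the pieces together: wrapping a legacy UDAF in the new wrapper and handing it to UserDefinedAggregator, mirroring the UDFRegistration change earlier in this diff. LongSummer is the UDAF from the e2e test suite above; the given name is illustrative, and package-private access to the wrapper is assumed.

// Illustrative only; mirrors UDFRegistration.register above.
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.expressions.UserDefinedAggregator

val udaf = new LongSummer(1) // UDAF defined in the e2e test suite above
val wrapped = UserDefinedAggregator(
  aggregator = new UserDefinedAggregateFunctionWrapper(udaf),
  inputEncoder = RowEncoder.encoderFor(udaf.inputSchema),
  givenName = Some("long_summer"), // hypothetical name
  deterministic = udaf.deterministic)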