
[SPARK-43082][CONNECT][PYTHON] Arrow-optimized Python UDFs in Spark Connect #40725

Closed
wants to merge 13 commits
12 changes: 9 additions & 3 deletions python/pyspark/sql/connect/functions.py
@@ -50,7 +50,7 @@
LambdaFunction,
UnresolvedNamedLambdaVariable,
)
from pyspark.sql.connect.udf import _create_udf
from pyspark.sql.connect.udf import _create_py_udf
from pyspark.sql import functions as pysparkfuncs
from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType

@@ -2461,6 +2461,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
def udf(
f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
returnType: "DataTypeOrString" = StringType(),
useArrow: Optional[bool] = None,
) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
from pyspark.rdd import PythonEvalType

@@ -2469,10 +2470,15 @@ def udf(
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(
_create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
_create_py_udf,
returnType=return_type,
evalType=PythonEvalType.SQL_BATCHED_UDF,
useArrow=useArrow,
)
else:
return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
return _create_py_udf(
f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
)


udf.__doc__ = pysparkfuncs.udf.__doc__
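For context, a minimal usage sketch of the new useArrow flag exposed above (not part of the diff; it assumes an active Spark Connect session bound to the name spark, and imports from the module modified here):

from pyspark.sql.connect.functions import udf

# Force Arrow optimization for this UDF, regardless of the session conf.
to_upper = udf(lambda s: s.upper(), returnType="string", useArrow=True)

# useArrow=None (the default) defers to spark.sql.execution.pythonUDF.arrow.enabled.
plain_upper = udf(lambda s: s.upper(), returnType="string")

df = spark.range(1).selectExpr("'connect' AS word")
df.select(to_upper("word"), plain_upper("word")).show()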
46 changes: 45 additions & 1 deletion python/pyspark/sql/connect/udf.py
@@ -23,6 +23,8 @@

import sys
import functools
import warnings
from inspect import getfullargspec
from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union

from pyspark.rdd import PythonEvalType
@@ -33,7 +35,7 @@
)
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.types import UnparsedDataType
from pyspark.sql.types import DataType, StringType
from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType
from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration


@@ -47,6 +49,48 @@
from pyspark.sql.types import StringType


def _create_py_udf(
f: Callable[..., Any],
returnType: "DataTypeOrString",
evalType: int,
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
from pyspark.sql.udf import _create_arrow_py_udf
from pyspark.sql.connect.session import _active_spark_session

if _active_spark_session is None:
is_arrow_enabled = False
else:
is_arrow_enabled = (
_active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true"
if useArrow is None
else useArrow
)

regular_udf = _create_udf(f, returnType, evalType)
Member Author:
There is duplicated code in _create_py_udf between the Spark Connect Python client and vanilla PySpark, except for fetching the active SparkSession. However, to keep the code paths cleanly separated and abstracted, I decided not to refactor it for now.

return_type = regular_udf.returnType
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
is_output_atomic_type = (
not isinstance(return_type, StructType)
and not isinstance(return_type, MapType)
and not isinstance(return_type, ArrayType)
)
if is_arrow_enabled:
if is_output_atomic_type and is_func_with_args:
return _create_arrow_py_udf(regular_udf)
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf


def _create_udf(
f: Callable[..., Any],
returnType: "DataTypeOrString",
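As a reading aid, a minimal sketch (not part of the patch; the helper name is hypothetical) of the enablement precedence implemented in _create_py_udf above: without an active session the optimization stays off; otherwise an explicit useArrow argument wins; only then does the session conf decide.

from typing import Optional

def _arrow_enabled(use_arrow: Optional[bool], session) -> bool:
    # Hypothetical helper mirroring the branching in _create_py_udf above.
    if session is None:
        return False  # no active Spark Connect session: fall back to a regular UDF
    if use_arrow is not None:
        return use_arrow  # per-UDF flag overrides the session-level conf
    return session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true"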
48 changes: 48 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
@@ -0,0 +1,48 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin


class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
@classmethod
def setUpClass(cls):
super(ArrowPythonUDFParityTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(ArrowPythonUDFParityTests, cls).tearDownClass()


if __name__ == "__main__":
import unittest
from pyspark.sql.tests.connect.test_parity_arrow_python_udf import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
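To run the new parity suite locally, the usual PySpark dev invocation should work (shown for illustration under the assumption of a full Spark build with the Connect module):

python/run-tests --testnames 'pyspark.sql.tests.connect.test_parity_arrow_python_udf'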
21 changes: 15 additions & 6 deletions python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -31,12 +31,7 @@
@unittest.skipIf(
not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
)
class PythonUDFArrowTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(PythonUDFArrowTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_broadcast_in_udf(self):
super(PythonUDFArrowTests, self).test_broadcast_in_udf()
@@ -118,6 +113,20 @@ def test_use_arrow(self):
self.assertEquals(row_false[0], "[1, 2, 3]")


class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(PythonUDFArrowTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@classmethod
def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(PythonUDFArrowTests, cls).tearDownClass()


if __name__ == "__main__":
from pyspark.sql.tests.test_arrow_python_udf import * # noqa: F401

41 changes: 0 additions & 41 deletions python/pyspark/sql/tests/test_udf.py
@@ -838,47 +838,6 @@ def setUpClass(cls):
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")


def test_use_arrow(self):
# useArrow=True
row_true = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=True)("array"),
)
.first()
)
# The input is a NumPy array when the Arrow optimization is on.
self.assertEquals(row_true[0], "[1 2 3]")

# useArrow=None
row_none = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=None)("array"),
)
.first()
)

# useArrow=False
row_false = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=False)("array"),
)
.first()
)
self.assertEquals(row_false[0], row_none[0]) # "[1, 2, 3]"


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
if SparkSession._instantiatedSession is not None:
93 changes: 54 additions & 39 deletions python/pyspark/sql/udf.py
@@ -75,6 +75,7 @@ def _create_udf(
name: Optional[str] = None,
deterministic: bool = True,
) -> "UserDefinedFunctionLike":
"""Create a regular(non-Arrow-optimized) Python UDF."""
# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
@@ -88,6 +89,7 @@ def _create_py_udf(
evalType: int,
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
"""Create a regular/Arrow-optimized Python UDF."""
# The following table shows the results when the type coercion in Arrow is needed, that is,
# when the user-specified return type(SQL Type) of the UDF and the actual instance(Python
# Value(Type)) that the UDF returns are different.
@@ -138,49 +140,62 @@ def _create_py_udf(
and not isinstance(return_type, MapType)
and not isinstance(return_type, ArrayType)
)
if is_arrow_enabled and is_output_atomic_type and is_func_with_args:
require_minimum_pandas_version()
require_minimum_pyarrow_version()

import pandas as pd
from pyspark.sql.pandas.functions import _create_pandas_udf # type: ignore[attr-defined]

# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def vectorized_udf(*args: pd.Series) -> pd.Series:
if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
raise NotImplementedError(
"Struct input type are not supported with Arrow optimization "
"enabled in Python UDFs. Disable "
"'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
)
return pd.Series(result_func(f(*a)) for a in zip(*args))

# Regular UDFs can take callable instances too.
vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
vectorized_udf.__module__ = (
f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
)
vectorized_udf.__doc__ = f.__doc__
pudf = _create_pandas_udf(vectorized_udf, returnType, None)
# Keep the attributes as if this is a regular Python UDF.
pudf.func = f
pudf.returnType = return_type
pudf.evalType = regular_udf.evalType
return pudf
if is_arrow_enabled:
if is_output_atomic_type and is_func_with_args:
return _create_arrow_py_udf(regular_udf)
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf


def _create_arrow_py_udf(regular_udf): # type: ignore
"""Create an Arrow-optimized Python UDF out of a regular Python UDF."""
require_minimum_pandas_version()
require_minimum_pyarrow_version()

import pandas as pd
from pyspark.sql.pandas.functions import _create_pandas_udf

f = regular_udf.func
return_type = regular_udf.returnType
Contributor:
It seems that the regular_udf is only used to pass the returnType and evalType?

Member Author:
And regular_udf.func, based on the updated code.


# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def vectorized_udf(*args: pd.Series) -> pd.Series:
if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
raise NotImplementedError(
"Struct input type are not supported with Arrow optimization "
"enabled in Python UDFs. Disable "
"'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
)
return pd.Series(result_func(f(*a)) for a in zip(*args))

# Regular UDFs can take callable instances too.
vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
vectorized_udf.__doc__ = f.__doc__
pudf = _create_pandas_udf(vectorized_udf, return_type, None)
# Keep the attributes as if this is a regular Python UDF.
pudf.func = f
pudf.returnType = return_type
pudf.evalType = regular_udf.evalType
return pudf
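
To make the result_func comment above concrete, a small illustrative snippet (not from the patch) of the coercion applied for a UDF declared with StringType:

def result_func(r):
    # Coerce non-None results to str, mirroring the StringType branch above, so the
    # Arrow path accepts values that the non-Arrow path already handles successfully.
    return str(r) if r is not None else r

assert result_func(42) == "42"    # coerced to the declared string type
assert result_func(None) is None  # NULLs pass through untouched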


class UserDefinedFunction:
"""
User defined function in Python