Skip to content

Commit

Permalink
[SPARK-42247][CONNECT][PYTHON] Fix UserDefinedFunction to have returnType
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

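Fix the Spark Connect `UserDefinedFunction` to expose a public `returnType` attribute (it was previously stored only as the private `_returnType`).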

### Why are the changes needed?

Currently, the Spark Connect `UserDefinedFunction` doesn't have a `returnType` attribute, so the parity tests that read it had to be skipped.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

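Enabled the previously skipped parity tests and modified the related tests: the cases that pass the return type as a string were split into separate tests, which the Spark Connect test suite overrides to compare against `UnparsedDataType`.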

Closes apache#40472 from ueshin/issues/SPARK-42247/returnType.

Lead-authored-by: Takuya UESHIN <ueshin@databricks.com>
Co-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
ueshin and HyukjinKwon committed Mar 20, 2023
1 parent e911c5e commit 708bda3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 33 deletions.
6 changes: 3 additions & 3 deletions python/pyspark/sql/connect/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
)

self.func = func
self._returnType: DataType = (
self.returnType: DataType = (
UnparsedDataType(returnType) if isinstance(returnType, str) else returnType
)
self._name = name or (
Expand All @@ -116,7 +116,7 @@ def _build_common_inline_user_defined_function(
arg_exprs = [col._expr for col in arg_cols]

py_udf = PythonUDF(
output_type=self._returnType,
output_type=self.returnType,
eval_type=self.evalType,
func=self.func,
python_ver="%d.%d" % sys.version_info[:2],
Expand Down Expand Up @@ -160,7 +160,7 @@ def wrapper(*args: "ColumnOrName") -> Column:
)

wrapper.func = self.func # type: ignore[attr-defined]
wrapper.returnType = self._returnType # type: ignore[attr-defined]
wrapper.returnType = self.returnType # type: ignore[attr-defined]
wrapper.evalType = self.evalType # type: ignore[attr-defined]
wrapper.deterministic = self.deterministic # type: ignore[attr-defined]
wrapper.asNondeterministic = functools.wraps( # type: ignore[attr-defined]
Expand Down
40 changes: 32 additions & 8 deletions python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import unittest

from pyspark.sql.connect.types import UnparsedDataType
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.tests.pandas.test_pandas_udf import PandasUDFTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase

Expand All @@ -32,15 +34,37 @@ def test_udf_wrong_arg(self):
def test_pandas_udf_timestamp_ntz(self):
super().test_pandas_udf_timestamp_ntz()

# TODO(SPARK-42247): standardize `returnType` attribute of UDF
@unittest.skip("Fails in Spark Connect, should enable.")
def test_pandas_udf_decorator(self):
super().test_pandas_udf_decorator()
def test_pandas_udf_decorator_with_return_type_string(self):
@pandas_udf("v double", PandasUDFType.GROUPED_MAP)
def foo(x):
return x

# TODO(SPARK-42247): standardize `returnType` attribute of UDF
@unittest.skip("Fails in Spark Connect, should enable.")
def test_pandas_udf_basic(self):
super().test_pandas_udf_basic()
self.assertEqual(foo.returnType, UnparsedDataType("v double"))
self.assertEqual(foo.evalType, PandasUDFType.GROUPED_MAP)

@pandas_udf(returnType="double", functionType=PandasUDFType.SCALAR)
def foo(x):
return x

self.assertEqual(foo.returnType, UnparsedDataType("double"))
self.assertEqual(foo.evalType, PandasUDFType.SCALAR)

def test_pandas_udf_basic_with_return_type_string(self):
udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, UnparsedDataType("double"))
self.assertEqual(udf.evalType, PandasUDFType.SCALAR)

udf = pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, UnparsedDataType("v double"))
self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)

udf = pandas_udf(lambda x: x, "v double", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, UnparsedDataType("v double"))
self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)

udf = pandas_udf(lambda x: x, returnType="v double", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, UnparsedDataType("v double"))
self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)

# TODO(SPARK-42340): implement GroupedData.applyInPandas
@unittest.skip("Fails in Spark Connect, should enable.")
Expand Down
10 changes: 0 additions & 10 deletions python/pyspark/sql/tests/connect/test_parity_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ def test_worker_original_stdin_closed(self):
def test_udf_on_sql_context(self):
super().test_udf_on_sql_context()

# TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
@unittest.skip("Fails in Spark Connect, should enable.")
def test_udf3(self):
super().test_udf3()

# TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
@unittest.skip("Fails in Spark Connect, should enable.")
def test_udf_registration_return_type_none(self):
super().test_udf_registration_return_type_none()

@unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
def test_non_existed_udf_with_sql_context(self):
super().test_non_existed_udf_with_sql_context()
Expand Down
28 changes: 16 additions & 12 deletions python/pyspark/sql/tests/pandas/test_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, StructType([StructField("v", DoubleType())]), PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_pandas_udf_basic_with_return_type_string(self):
udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
Expand Down Expand Up @@ -93,33 +94,36 @@ def foo(x):
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

@pandas_udf("v double", PandasUDFType.GROUPED_MAP)
@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def foo(x):
return x

self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
def foo(x):
return x

self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

@pandas_udf(returnType="double", functionType=PandasUDFType.SCALAR)
def test_pandas_udf_decorator_with_return_type_string(self):
schema = StructType([StructField("v", DoubleType())])

@pandas_udf("v double", PandasUDFType.GROUPED_MAP)
def foo(x):
return x

self.assertEqual(foo.returnType, DoubleType())
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
@pandas_udf(returnType="double", functionType=PandasUDFType.SCALAR)
def foo(x):
return x

self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
self.assertEqual(foo.returnType, DoubleType())
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

def test_udf_wrong_arg(self):
with QuietTest(self.sc):
Expand Down

0 comments on commit 708bda3

Please sign in to comment.