diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 2128c908081a0..9afc6e0e626a5 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -98,7 +98,7 @@ def __init__(
             )
 
         self.func = func
-        self._returnType: DataType = (
+        self.returnType: DataType = (
             UnparsedDataType(returnType) if isinstance(returnType, str) else returnType
         )
         self._name = name or (
@@ -116,7 +116,7 @@ def _build_common_inline_user_defined_function(
         arg_exprs = [col._expr for col in arg_cols]
 
         py_udf = PythonUDF(
-            output_type=self._returnType,
+            output_type=self.returnType,
             eval_type=self.evalType,
             func=self.func,
             python_ver="%d.%d" % sys.version_info[:2],
@@ -160,7 +160,7 @@ def wrapper(*args: "ColumnOrName") -> Column:
         )
 
         wrapper.func = self.func  # type: ignore[attr-defined]
-        wrapper.returnType = self._returnType  # type: ignore[attr-defined]
+        wrapper.returnType = self.returnType  # type: ignore[attr-defined]
         wrapper.evalType = self.evalType  # type: ignore[attr-defined]
         wrapper.deterministic = self.deterministic  # type: ignore[attr-defined]
         wrapper.asNondeterministic = functools.wraps(  # type: ignore[attr-defined]
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
index 841ade40f5e5a..571ee74287e96 100644
--- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
@@ -17,6 +17,8 @@
 
 import unittest
 
+from pyspark.sql.connect.types import UnparsedDataType
+from pyspark.sql.functions import pandas_udf, PandasUDFType
 from pyspark.sql.tests.pandas.test_pandas_udf import PandasUDFTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
@@ -32,15 +34,37 @@ def test_udf_wrong_arg(self):
     def test_pandas_udf_timestamp_ntz(self):
         super().test_pandas_udf_timestamp_ntz()
 
-    # TODO(SPARK-42247): standardize `returnType` attribute of UDF
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_pandas_udf_decorator(self):
-        super().test_pandas_udf_decorator()
+    def test_pandas_udf_decorator_with_return_type_string(self):
+        @pandas_udf("v double", PandasUDFType.GROUPED_MAP)
+        def foo(x):
+            return x
 
-    # TODO(SPARK-42247): standardize `returnType` attribute of UDF
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_pandas_udf_basic(self):
-        super().test_pandas_udf_basic()
+        self.assertEqual(foo.returnType, UnparsedDataType("v double"))
+        self.assertEqual(foo.evalType, PandasUDFType.GROUPED_MAP)
+
+        @pandas_udf(returnType="double", functionType=PandasUDFType.SCALAR)
+        def foo(x):
+            return x
+
+        self.assertEqual(foo.returnType, UnparsedDataType("double"))
+        self.assertEqual(foo.evalType, PandasUDFType.SCALAR)
+
+    def test_pandas_udf_basic_with_return_type_string(self):
+        udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
+        self.assertEqual(udf.returnType, UnparsedDataType("double"))
+        self.assertEqual(udf.evalType, PandasUDFType.SCALAR)
+
+        udf = pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP)
+        self.assertEqual(udf.returnType, UnparsedDataType("v double"))
+        self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)
+
+        udf = pandas_udf(lambda x: x, "v double", functionType=PandasUDFType.GROUPED_MAP)
+        self.assertEqual(udf.returnType, UnparsedDataType("v double"))
+        self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)
+
+        udf = pandas_udf(
+            lambda x: x, returnType="v double", functionType=PandasUDFType.GROUPED_MAP
+        )
+        self.assertEqual(udf.returnType, UnparsedDataType("v double"))
+        self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)
 
     # TODO(SPARK-42340): implement GroupedData.applyInPandas
     @unittest.skip("Fails in Spark Connect, should enable.")
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 10d93b71ebf2f..1be7d69b8c329 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -78,16 +78,6 @@ def test_worker_original_stdin_closed(self):
     def test_udf_on_sql_context(self):
         super().test_udf_on_sql_context()
 
-    # TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf3(self):
-        super().test_udf3()
-
-    # TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_registration_return_type_none(self):
-        super().test_udf_registration_return_type_none()
-
     @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
     def test_non_existed_udf_with_sql_context(self):
         super().test_non_existed_udf_with_sql_context()
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 0f92711313040..4e1eec38a0cb1 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -47,16 +47,17 @@ def test_pandas_udf_basic(self):
         self.assertEqual(udf.returnType, DoubleType())
         self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
 
-        udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
-        self.assertEqual(udf.returnType, DoubleType())
-        self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
-
         udf = pandas_udf(
             lambda x: x, StructType([StructField("v", DoubleType())]), PandasUDFType.GROUPED_MAP
         )
         self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
         self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
 
+    def test_pandas_udf_basic_with_return_type_string(self):
+        udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
+        self.assertEqual(udf.returnType, DoubleType())
+        self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
         udf = pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP)
         self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
         self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
@@ -93,33 +94,36 @@ def foo(x):
         self.assertEqual(foo.returnType, schema)
         self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
 
-        @pandas_udf("v double", PandasUDFType.GROUPED_MAP)
+        @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
         def foo(x):
             return x
 
         self.assertEqual(foo.returnType, schema)
         self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
 
-        @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
+        @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
         def foo(x):
             return x
 
         self.assertEqual(foo.returnType, schema)
         self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
 
-        @pandas_udf(returnType="double", functionType=PandasUDFType.SCALAR)
+    def test_pandas_udf_decorator_with_return_type_string(self):
+        schema = StructType([StructField("v", DoubleType())])
+
+        @pandas_udf("v double", PandasUDFType.GROUPED_MAP)
         def foo(x):
             return x
 
-        self.assertEqual(foo.returnType, DoubleType())
-        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+        self.assertEqual(foo.returnType, schema)
+        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
 
-        @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
+        @pandas_udf(returnType="double", functionType=PandasUDFType.SCALAR)
         def foo(x):
             return x
 
-        self.assertEqual(foo.returnType, schema)
-        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+        self.assertEqual(foo.returnType, DoubleType())
+        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
 
     def test_udf_wrong_arg(self):
         with QuietTest(self.sc):
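
For reference, a minimal sketch (not part of the patch) of the behavior the new parity tests assert, assuming a pyspark session running in Spark Connect mode. On the Connect client a string return type is kept as UnparsedDataType rather than being parsed eagerly, and that is what the now-public returnType attribute exposes; the classic-Spark test in test_pandas_udf.py asserts the parsed type instead.

    # Illustration only; mirrors the assertions added in test_parity_pandas_udf.py,
    # assuming the session was created as a Spark Connect (remote) session.
    from pyspark.sql.connect.types import UnparsedDataType
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf("v double", PandasUDFType.GROUPED_MAP)
    def foo(x):
        return x

    # The string return type stays unparsed on the Connect client, so the
    # attribute compares equal to UnparsedDataType("v double") rather than
    # a parsed StructType.
    assert foo.returnType == UnparsedDataType("v double")
    assert foo.evalType == PandasUDFType.GROUPED_MAP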