Patch summary (free text before the first "diff --git" header):
  * UnresolvedNamedLambdaVariable.__repr__ now returns the name parts joined
    by ", " instead of the wrapped "(UnresolvedNamedLambdaVariable(..." form,
    whose opening parenthesis was never closed.
  * LambdaFunction.__repr__ drops the same stray leading "(" and converts
    every argument with str() before joining them.
  * Adds test_lambda_str_representation, which resets
    UnresolvedNamedLambdaVariable._nextVarNameId so the generated variable
    names (x_0, y_1) in the expected string are deterministic.

diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 8cd386ba03aea..aac56c045ee21 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -914,7 +914,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         return expr
 
     def __repr__(self) -> str:
-        return f"(UnresolvedNamedLambdaVariable({', '.join(self._name_parts)})"
+        return ", ".join(self._name_parts)
 
     @staticmethod
     def fresh_var_name(name: str) -> str:
@@ -959,7 +959,10 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         return expr
 
     def __repr__(self) -> str:
-        return f"(LambdaFunction({str(self._function)}, {', '.join(self._arguments)})"
+        return (
+            f"LambdaFunction({str(self._function)}, "
+            + f"{', '.join([str(arg) for arg in self._arguments])})"
+        )
 
 
 class WindowExpression(Expression):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 9a850dcae6f53..fbfb4486446ff 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -1024,6 +1024,28 @@ def test_distributed_sequence_id(self):
             expected.collect(),
         )
 
+    def test_lambda_str_representation(self):
+        from pyspark.sql.connect.expressions import UnresolvedNamedLambdaVariable
+
+        # forcely clear the internal increasing id,
+        # otherwise the string representation varies with this id
+        UnresolvedNamedLambdaVariable._nextVarNameId = 0
+
+        c = CF.array_sort(
+            "data",
+            lambda x, y: CF.when(x.isNull() | y.isNull(), CF.lit(0)).otherwise(
+                CF.length(y) - CF.length(x)
+            ),
+        )
+
+        self.assertEqual(
+            str(c),
+            (
+                """Column<'array_sort(data, LambdaFunction(CASE WHEN or(isNull(x_0), """
+                """isNull(y_1)) THEN 0 ELSE -(length(y_1), length(x_0)) END, x_0, y_1))'>"""
+            ),
+        )
+
 
 if __name__ == "__main__":
     import unittest