diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 03daf18eadbf3..48aee48d929c8 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -33,6 +33,22 @@ require_minimum_pyarrow_version() +def dataframe_to_arrow_table_example(spark: SparkSession) -> None: + import pyarrow as pa # noqa: F401 + from pyspark.sql.functions import rand + + # Create a Spark DataFrame + df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()}) + + # Convert the Spark DataFrame to a PyArrow Table + table = df.select("*").toArrow() + + print(table.schema) + # 0: double not null + # 1: double not null + # 2: double not null + + def dataframe_with_arrow_example(spark: SparkSession) -> None: import numpy as np import pandas as pd @@ -302,6 +318,8 @@ def arrow_slen(s): # type: ignore[no-untyped-def] .appName("Python Arrow-in-Spark example") \ .getOrCreate() + print("Running Arrow conversion example: DataFrame to Table") + dataframe_to_arrow_table_example(spark) print("Running Pandas to/from conversion example") dataframe_with_arrow_example(spark) print("Running pandas_udf example: Series to Frame") diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst index b69a2771b04fc..ec39b645b1403 100644 --- a/python/docs/source/reference/pyspark.sql/dataframe.rst +++ b/python/docs/source/reference/pyspark.sql/dataframe.rst @@ -109,6 +109,7 @@ DataFrame DataFrame.tail DataFrame.take DataFrame.to + DataFrame.toArrow DataFrame.toDF DataFrame.toJSON DataFrame.toLocalIterator diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst index 1d6a4df606906..0a527d832e211 100644 --- a/python/docs/source/user_guide/sql/arrow_pandas.rst +++ b/python/docs/source/user_guide/sql/arrow_pandas.rst @@ -39,6 +39,20 @@ is installed and available on all cluster nodes. You can install it using pip or conda from the conda-forge channel. See PyArrow `installation `_ for details. +Conversion to Arrow Table +------------------------- + +You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table. + +.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py + :language: python + :lines: 37-49 + :dedent: 4 + +Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to +the driver program and should be done on a small subset of the data. Not all Spark data types are +currently supported and an error can be raised if a column has an unsupported type. + Enabling for Conversion to/from Pandas -------------------------------------- @@ -53,7 +67,7 @@ This can be controlled by ``spark.sql.execution.arrow.pyspark.fallback.enabled`` .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 37-52 + :lines: 53-68 :dedent: 4 Using the above optimizations with Arrow will produce the same results as when Arrow is not @@ -90,7 +104,7 @@ specify the type hints of ``pandas.Series`` and ``pandas.DataFrame`` as below: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 56-80 + :lines: 72-96 :dedent: 4 In the following sections, it describes the combinations of the supported type hints. For simplicity, @@ -113,7 +127,7 @@ The following example shows how to create this Pandas UDF that computes the prod .. 
literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 84-114 + :lines: 100-130 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -152,7 +166,7 @@ The following example shows how to create this Pandas UDF: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 118-140 + :lines: 134-156 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -174,7 +188,7 @@ The following example shows how to create this Pandas UDF: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 144-167 + :lines: 160-183 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -205,7 +219,7 @@ and window operations: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 171-212 + :lines: 187-228 :dedent: 4 .. currentmodule:: pyspark.sql.functions @@ -270,7 +284,7 @@ in the group. .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 216-234 + :lines: 232-250 :dedent: 4 For detailed usage, please see please see :meth:`GroupedData.applyInPandas` @@ -288,7 +302,7 @@ The following example shows how to use :meth:`DataFrame.mapInPandas`: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 238-249 + :lines: 254-265 :dedent: 4 For detailed usage, please see :meth:`DataFrame.mapInPandas`. @@ -327,7 +341,7 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 253-275 + :lines: 269-291 :dedent: 4 @@ -349,7 +363,7 @@ Here's an example that demonstrates the usage of both a default, pickled Python .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 279-297 + :lines: 295-313 :dedent: 4 Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF @@ -421,9 +435,12 @@ be verified by the user. Setting Arrow ``self_destruct`` for memory savings ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas DataFrame. -This option is experimental, and some operations may fail on the resulting Pandas DataFrame due to immutable backing arrays. -Typically, you would see the error ``ValueError: buffer source array is read-only``. -Newer versions of Pandas may fix these errors by improving support for such cases. -You can work around this error by copying the column(s) beforehand. -Additionally, this conversion may be slower because it is single-threaded. +Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` +can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a +Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas +DataFrame. This option can also save memory when creating a PyArrow Table via ``toArrow``. +This option is experimental. When used with ``toPandas``, some operations may fail on the resulting +Pandas DataFrame due to immutable backing arrays. 
Typically, you would see the error +``ValueError: buffer source array is read-only``. Newer versions of Pandas may fix these errors by +improving support for such cases. You can work around this error by copying the column(s) +beforehand. Additionally, this conversion may be slower because it is single-threaded. diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index db9f22517ddad..9b6790d29aaa7 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -74,6 +74,7 @@ if TYPE_CHECKING: from py4j.java_gateway import JavaObject + import pyarrow as pa from pyspark.core.rdd import RDD from pyspark.core.context import SparkContext from pyspark._typing import PrimitiveType @@ -1825,6 +1826,9 @@ def mapInArrow( ) -> ParentDataFrame: return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile) + def toArrow(self) -> "pa.Table": + return PandasConversionMixin.toArrow(self) + def toPandas(self) -> "PandasDataFrameLike": return PandasConversionMixin.toPandas(self) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 843c92a9b27d2..3c9415adec2dd 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1768,6 +1768,10 @@ def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: assert table is not None return (table, schema) + def toArrow(self) -> "pa.Table": + table, _ = self._to_table() + return table + def toPandas(self) -> "PandasDataFrameLike": query = self._plan.to_proto(self._session.client) return self._session.client.to_pandas(query, self._plan.observations) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e3d52c45d0c1d..886f72cc371e9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from py4j.java_gateway import JavaObject + import pyarrow as pa from pyspark.core.context import SparkContext from pyspark.core.rdd import RDD from pyspark._typing import PrimitiveType @@ -1200,6 +1201,7 @@ def collect(self) -> List[Row]: DataFrame.take : Returns the first `n` rows. DataFrame.head : Returns the first `n` rows. DataFrame.toPandas : Returns the data as a pandas DataFrame. + DataFrame.toArrow : Returns the data as a PyArrow Table. Notes ----- @@ -6213,6 +6215,34 @@ def mapInArrow( """ ... + @dispatch_df_method + def toArrow(self) -> "pa.Table": + """ + Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. + + This is only available if PyArrow is installed and available. + + .. versionadded:: 4.0.0 + + Notes + ----- + This method should only be used if the resulting PyArrow ``pyarrow.Table`` is + expected to be small, as all the data is loaded into the driver's memory. + + This API is a developer API. + + Examples + -------- + >>> df.toArrow() # doctest: +SKIP + pyarrow.Table + age: int64 + name: string + ---- + age: [[2,5]] + name: [["Alice","Bob"]] + """ + ... + def toPandas(self) -> "PandasDataFrameLike": """ Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. 
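For context, a minimal standalone sketch of the usage documented above: converting a small Spark DataFrame to a PyArrow Table with the ``DataFrame.toArrow()`` method added in this patch, optionally with the ``selfDestruct`` configuration that the docs say also applies to ``toArrow``. The app name and sample rows are illustrative and not taken from the patch.

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("toArrow-sketch")  # hypothetical app name, for illustration only
    # Optional: per the docs above, this experimental setting can also save memory
    # when creating a PyArrow Table via toArrow().
    .config("spark.sql.execution.arrow.pyspark.selfDestruct.enabled", "true")
    .getOrCreate()
)

# Small sample data; toArrow() collects all records to the driver,
# so it should only be used on a small subset of the data.
df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])

table = df.toArrow()
print(table.schema)    # age: int64, name: string (as in the docstring example above)
print(table.num_rows)  # 2

spark.stop()
```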
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index ec4e21daba97b..344608317beb7 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -225,15 +225,48 @@ def toPandas(self) -> "PandasDataFrameLike": else: return pdf - def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]: + def toArrow(self) -> "pa.Table": + from pyspark.sql.dataframe import DataFrame + + assert isinstance(self, DataFrame) + + jconf = self.sparkSession._jconf + + from pyspark.sql.pandas.types import to_arrow_schema + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version + + require_minimum_pyarrow_version() + to_arrow_schema(self.schema) + + import pyarrow as pa + + self_destruct = jconf.arrowPySparkSelfDestructEnabled() + batches = self._collect_as_arrow( + split_batches=self_destruct, empty_list_if_zero_records=False + ) + table = pa.Table.from_batches(batches) + # Ensure only the table has a reference to the batches, so that + # self_destruct (if enabled) is effective + del batches + return table + + def _collect_as_arrow( + self, + split_batches: bool = False, + empty_list_if_zero_records: bool = True, + ) -> List["pa.RecordBatch"]: """ - Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + Returns all records as a list of Arrow RecordBatches. PyArrow must be installed and available on driver and worker Python environments. This is an experimental feature. :param split_batches: split batches such that each column is in its own allocation, so that the selfDestruct optimization is effective; default False. + :param empty_list_if_zero_records: If True (the default), returns an empty list if the + result has 0 records. Otherwise, returns a list of length 1 containing an empty + Arrow RecordBatch which includes the schema. + .. note:: Experimental. 
""" from pyspark.sql.dataframe import DataFrame @@ -282,8 +315,15 @@ def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch batches = results[:-1] batch_order = results[-1] - # Re-order the batch list using the correct order - return [batches[i] for i in batch_order] + if len(batches) or empty_list_if_zero_records: + # Re-order the batch list using the correct order + return [batches[i] for i in batch_order] + else: + from pyspark.sql.pandas.types import to_arrow_schema + + schema = to_arrow_schema(self.schema) + empty_arrays = [pa.array([], type=field.type) for field in schema] + return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)] class SparkConversionMixin: diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 8636e953aaf8f..71d3c46e5ee1e 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -179,6 +179,35 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) + def create_arrow_table(self): + import pyarrow as pa + import pyarrow.compute as pc + + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + t = pa.Table.from_pydict(data_dict) + # convert these to Arrow types + new_schema = t.schema.set( + t.schema.get_field_index("2_int_t"), pa.field("2_int_t", pa.int32()) + ) + new_schema = new_schema.set( + new_schema.get_field_index("4_float_t"), pa.field("4_float_t", pa.float32()) + ) + new_schema = new_schema.set( + new_schema.get_field_index("6_decimal_t"), + pa.field("6_decimal_t", pa.decimal128(38, 18)), + ) + t = t.cast(new_schema) + # convert timestamp to local timezone + timezone = self.spark.conf.get("spark.sql.session.timeZone") + t = t.set_column( + t.schema.get_field_index("8_timestamp_t"), + "8_timestamp_t", + pc.assume_timezone(t["8_timestamp_t"], timezone), + ) + return t + @property def create_np_arrs(self): import numpy as np @@ -339,6 +368,12 @@ def test_pandas_round_trip(self): pdf_arrow = df.toPandas() assert_frame_equal(pdf_arrow, pdf) + def test_arrow_round_trip(self): + t_in = self.create_arrow_table() + df = self.spark.createDataFrame(self.data, schema=self.schema) + t_out = df.toArrow() + self.assertTrue(t_out.equals(t_in)) + def test_pandas_self_destruct(self): import pyarrow as pa