[SPARK-48220][PYTHON] Allow passing PyArrow Table to createDataFrame()
### What changes were proposed in this pull request?
- Add support for passing a PyArrow Table to `createDataFrame()`.
- Document this on the **Apache Arrow in PySpark** user guide page.
- Fix an issue with timestamp and struct columns in `toArrow()`.

### Why are the changes needed?
This seems like a logical next step after the addition of a `toArrow()` DataFrame method in #45481.

### Does this PR introduce _any_ user-facing change?
Users will be able to pass a PyArrow Table to `createDataFrame()`. The parameters of `createDataFrame()` are unchanged; the only difference is that `data` can now be a PyArrow Table.
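
For example, a minimal usage sketch (illustrative, not part of this PR's diff; it assumes an active SparkSession named `spark` and PyArrow installed):

```python
import pyarrow as pa

# Build a small PyArrow Table
table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# `data` can now be a PyArrow Table; no other parameters change
df = spark.createDataFrame(table)
df.show()

# Round trip back to Arrow via the toArrow() method added in #45481
result = df.toArrow()
```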

### How was this patch tested?
Many tests were added for both Spark Classic and Spark Connect. I ran the tests locally with older versions of PyArrow installed (going back to 10.0).

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46529 from ianmcook/SPARK-48220.

Authored-by: Ian Cook <ianmcook@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ianmcook authored and HyukjinKwon committed Jun 2, 2024
1 parent 8cf3195 commit bc18701
Showing 11 changed files with 1,048 additions and 196 deletions.
25 changes: 14 additions & 11 deletions examples/src/main/python/sql/arrow.py
@@ -33,20 +33,23 @@
 require_minimum_pyarrow_version()
 
 
-def dataframe_to_arrow_table_example(spark: SparkSession) -> None:
-    import pyarrow as pa  # noqa: F401
-    from pyspark.sql.functions import rand
+def dataframe_to_from_arrow_table_example(spark: SparkSession) -> None:
+    import pyarrow as pa
+    import numpy as np
+
+    # Create a PyArrow Table
+    table = pa.table([pa.array(np.random.rand(100)) for i in range(3)], names=["a", "b", "c"])
 
-    # Create a Spark DataFrame
-    df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()})
+    # Create a Spark DataFrame from the PyArrow Table
+    df = spark.createDataFrame(table)
 
     # Convert the Spark DataFrame to a PyArrow Table
-    table = df.select("*").toArrow()
+    result_table = df.select("*").toArrow()
 
-    print(table.schema)
-    # 0: double not null
-    # 1: double not null
-    # 2: double not null
+    print(result_table.schema)
+    # a: double
+    # b: double
+    # c: double
 
 
 def dataframe_with_arrow_example(spark: SparkSession) -> None:
@@ -319,7 +322,7 @@ def arrow_slen(s):  # type: ignore[no-untyped-def]
         .getOrCreate()
 
     print("Running Arrow conversion example: DataFrame to Table")
-    dataframe_to_arrow_table_example(spark)
+    dataframe_to_from_arrow_table_example(spark)
     print("Running Pandas to/from conversion example")
     dataframe_with_arrow_example(spark)
     print("Running pandas_udf example: Series to Frame")
48 changes: 27 additions & 21 deletions python/docs/source/user_guide/sql/arrow_pandas.rst
@@ -39,19 +39,21 @@ is installed and available on all cluster nodes.
 You can install it using pip or conda from the conda-forge channel. See PyArrow
 `installation <https://arrow.apache.org/docs/python/install.html>`_ for details.
 
-Conversion to Arrow Table
--------------------------
+Conversion to/from Arrow Table
+------------------------------
 
-You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table.
+From Spark 4.0, you can create a Spark DataFrame from a PyArrow Table with
+:meth:`SparkSession.createDataFrame`, and you can convert a Spark DataFrame to a PyArrow Table
+with :meth:`DataFrame.toArrow`.
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 37-49
+    :lines: 37-52
     :dedent: 4
 
 Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to
-the driver program and should be done on a small subset of the data. Not all Spark data types are
-currently supported and an error can be raised if a column has an unsupported type.
+the driver program and should be done on a small subset of the data. Not all Spark and Arrow data
+types are currently supported and an error can be raised if a column has an unsupported type.
 
 Enabling for Conversion to/from Pandas
 --------------------------------------
@@ -67,7 +69,7 @@ This can be controlled by ``spark.sql.execution.arrow.pyspark.fallback.enabled``
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 53-68
+    :lines: 56-71
     :dedent: 4
 
 Using the above optimizations with Arrow will produce the same results as when Arrow is not
@@ -104,7 +106,7 @@ specify the type hints of ``pandas.Series`` and ``pandas.DataFrame`` as below:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 72-96
+    :lines: 75-99
     :dedent: 4
 
 In the following sections, it describes the combinations of the supported type hints. For simplicity,
@@ -127,7 +129,7 @@ The following example shows how to create this Pandas UDF that computes the prod
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 100-130
+    :lines: 103-133
     :dedent: 4
 
 For detailed usage, please see :func:`pandas_udf`.
@@ -166,7 +168,7 @@ The following example shows how to create this Pandas UDF:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 134-156
+    :lines: 137-159
     :dedent: 4
 
 For detailed usage, please see :func:`pandas_udf`.
@@ -188,7 +190,7 @@ The following example shows how to create this Pandas UDF:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
    :language: python
-    :lines: 160-183
+    :lines: 163-186
    :dedent: 4
 
 For detailed usage, please see :func:`pandas_udf`.
@@ -219,7 +221,7 @@ and window operations:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 187-228
+    :lines: 190-231
     :dedent: 4
 
 .. currentmodule:: pyspark.sql.functions
@@ -284,7 +286,7 @@ in the group.
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 232-250
+    :lines: 235-253
     :dedent: 4
 
 For detailed usage, please see please see :meth:`GroupedData.applyInPandas`
@@ -302,7 +304,7 @@ The following example shows how to use :meth:`DataFrame.mapInPandas`:
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 254-265
+    :lines: 257-268
     :dedent: 4
 
 For detailed usage, please see :meth:`DataFrame.mapInPandas`.
@@ -341,7 +343,7 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 269-291
+    :lines: 272-294
     :dedent: 4
 
 
@@ -363,7 +365,7 @@ Here's an example that demonstrates the usage of both a default, pickled Python
 
 .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
     :language: python
-    :lines: 295-313
+    :lines: 298-316
     :dedent: 4
 
 Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF
@@ -414,11 +416,15 @@ and each column will be converted to the Spark session time zone then localized
 zone, which removes the time zone and displays values as local time. This will occur
 when calling :meth:`DataFrame.toPandas()` or ``pandas_udf`` with timestamp columns.
 
-When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This
-occurs when calling :meth:`SparkSession.createDataFrame` with a Pandas DataFrame or when returning a timestamp from a
-``pandas_udf``. These conversions are done automatically to ensure Spark will have data in the
-expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond
-values will be truncated.
+When timestamp data is transferred from Spark to a PyArrow Table, it will remain in microsecond
+resolution with the UTC time zone. This occurs when calling :meth:`DataFrame.toArrow()` with
+timestamp columns.
+
+When timestamp data is transferred from Pandas or PyArrow to Spark, it will be converted to UTC
+microseconds. This occurs when calling :meth:`SparkSession.createDataFrame` with a Pandas DataFrame
+or PyArrow Table, or when returning a timestamp from a ``pandas_udf``. These conversions are done
+automatically to ensure Spark will have data in the expected format, so it is not necessary to do
+any of these conversions yourself. Any nanosecond values will be truncated.
 
 Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is
 different from a Pandas timestamp. It is recommended to use Pandas time series functionality when
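
An illustrative sketch of the timestamp behavior documented above (not part of the diff); it assumes an active SparkSession named `spark` with Arrow-based conversion enabled:

```python
import pandas as pd

# Nanosecond-precision timestamp in pandas
pdf = pd.DataFrame({"ts": pd.to_datetime(["2024-06-02 12:34:56.123456789"])})

# Spark stores microsecond-resolution timestamps, so the trailing nanoseconds
# (...789) are expected to be truncated on conversion
sdf = spark.createDataFrame(pdf)
sdf.show(truncate=False)

# Per the paragraph above, toArrow() is expected to keep microsecond resolution
# with the UTC time zone in the resulting Arrow schema
print(sdf.toArrow().schema)
```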
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -82,7 +82,7 @@
     UnresolvedStar,
 )
 from pyspark.sql.connect.functions import builtin as F
-from pyspark.sql.pandas.types import from_arrow_schema
+from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
 from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
 
 
@@ -1770,8 +1770,9 @@ def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
         return (table, schema)
 
     def toArrow(self) -> "pa.Table":
+        schema = to_arrow_schema(self.schema, error_on_duplicated_field_names_in_struct=True)
         table, _ = self._to_table()
-        return table
+        return table.cast(schema)
 
     def toPandas(self) -> "PandasDataFrameLike":
         query = self._plan.to_proto(self._session.client)
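
The added cast aligns the collected Arrow table with the Arrow schema derived from the DataFrame's Spark schema (the timestamp and struct fix mentioned in the PR description). A rough illustration of that idea with hypothetical data, not the Spark Connect internals:

```python
import pyarrow as pa

# A collected table with nanosecond timestamps (hypothetical)
collected = pa.table({"ts": pa.array([1_000_000_000, 2_000_000_000], type=pa.timestamp("ns"))})

# A target schema as it might be derived from the Spark schema
target = pa.schema([pa.field("ts", pa.timestamp("us"))])

# Casting normalizes details such as timestamp resolution
normalized = collected.cast(target)
print(normalized.schema)  # ts: timestamp[us]
```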
28 changes: 27 additions & 1 deletion python/pyspark/sql/connect/session.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
+from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 check_dependencies(__name__)
 
@@ -73,7 +74,9 @@
     to_arrow_schema,
     to_arrow_type,
     _deduplicate_field_names,
+    from_arrow_schema,
     from_arrow_type,
+    _check_arrow_table_timestamps_localize,
 )
 from pyspark.sql.profiler import Profile
 from pyspark.sql.session import classproperty, SparkSession as PySparkSession
@@ -413,7 +416,7 @@ def _inferSchemaFromList(
 
     def createDataFrame(
         self,
-        data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
+        data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]],
         schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None,
         samplingRatio: Optional[float] = None,
         verifySchema: Optional[bool] = None,
@@ -476,6 +479,7 @@ def createDataFrame(
         )
 
         _table: Optional[pa.Table] = None
+        timezone: Optional[str] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -561,6 +565,28 @@
                 cast(StructType, _deduplicate_field_names(schema)).names
             ).cast(arrow_schema)
 
+        elif isinstance(data, pa.Table):
+            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
+
+            (timezone,) = self._client.get_configs("spark.sql.session.timeZone")
+
+            # If no schema supplied by user then get the names of columns only
+            if schema is None:
+                _cols = data.column_names
+            if isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns):
+                assert isinstance(_cols, list)
+                _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
+                _num_cols = len(_cols)
+
+            if not isinstance(schema, StructType):
+                schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)
+
+            _table = (
+                _check_arrow_table_timestamps_localize(data, schema, True, timezone)
+                .cast(to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True))
+                .rename_columns(schema.names)
+            )
+
         elif isinstance(data, np.ndarray):
             if _cols is None:
                 if data.ndim == 1 or data.shape[1] == 1:
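
A hedged sketch of the user-facing schema handling this branch enables (assumes an active SparkSession named `spark`; the column and field names below are made up):

```python
import pyarrow as pa
from pyspark.sql.types import StructType, StructField, LongType, StringType

table = pa.table({"a": [1, 2], "b": ["x", "y"]})

# No schema: column names and types are taken from the Arrow schema
df1 = spark.createDataFrame(table)

# List of column names: columns are renamed positionally, types are still inferred
df2 = spark.createDataFrame(table, schema=["col_a", "col_b"])

# Explicit StructType: the Arrow table is cast to the corresponding Arrow types
# and renamed to the StructType's field names
df3 = spark.createDataFrame(table, schema=StructType([
    StructField("col_a", LongType()),
    StructField("col_b", StringType()),
]))
```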
23 changes: 17 additions & 6 deletions python/pyspark/sql/context.py
@@ -46,6 +46,7 @@
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
+    import pyarrow as pa
     from pyspark.core.rdd import RDD
     from pyspark.core.context import SparkContext
     from pyspark.sql._typing import (
@@ -343,28 +344,29 @@ def createDataFrame(
 
     @overload
     def createDataFrame(
-        self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
+        self, data: Union["PandasDataFrameLike", "pa.Table"], samplingRatio: Optional[float] = ...
     ) -> DataFrame:
         ...
 
     @overload
     def createDataFrame(
         self,
-        data: "PandasDataFrameLike",
+        data: Union["PandasDataFrameLike", "pa.Table"],
         schema: Union[StructType, str],
         verifySchema: bool = ...,
     ) -> DataFrame:
         ...
 
     def createDataFrame(  # type: ignore[misc]
         self,
-        data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike"],
+        data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "pa.Table"],
         schema: Optional[Union[AtomicType, StructType, str]] = None,
         samplingRatio: Optional[float] = None,
         verifySchema: bool = True,
     ) -> DataFrame:
         """
-        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
+        Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`, or
+        a :class:`pyarrow.Table`.
 
         When ``schema`` is a list of column names, the type of each column
         will be inferred from ``data``.
@@ -393,12 +395,15 @@ def createDataFrame(  # type: ignore[misc]
         .. versionchanged:: 2.1.0
            Added verifySchema.
 
+        .. versionchanged:: 4.0.0
+           Added support for :class:`pyarrow.Table`.
+
         Parameters
         ----------
         data : :class:`RDD` or iterable
             an RDD of any kind of SQL data representation (:class:`Row`,
-            :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or
-            :class:`pandas.DataFrame`.
+            :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`,
+            :class:`pandas.DataFrame`, or :class:`pyarrow.Table`.
         schema : :class:`pyspark.sql.types.DataType`, str or list, optional
             a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
             column names, default is None. The data type string format equals to
@@ -452,6 +457,12 @@ def createDataFrame(  # type: ignore[misc]
         >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
         [Row(0=1, 1=2)]
+
+        >>> sqlContext.createDataFrame(df.toArrow()).collect()  # doctest: +SKIP
+        [Row(name='Alice', age=1)]
+        >>> table = pyarrow.table({'0': [1], '1': [2]})  # doctest: +SKIP
+        >>> sqlContext.createDataFrame(table).collect()  # doctest: +SKIP
+        [Row(0=1, 1=2)]
 
         >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect()
         [Row(a='Alice', b=1)]
         >>> rdd = rdd.map(lambda row: row[1])