diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 48aee48d929c8..0200d094185d5 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -33,20 +33,23 @@ require_minimum_pyarrow_version() -def dataframe_to_arrow_table_example(spark: SparkSession) -> None: - import pyarrow as pa # noqa: F401 - from pyspark.sql.functions import rand +def dataframe_to_from_arrow_table_example(spark: SparkSession) -> None: + import pyarrow as pa + import numpy as np + + # Create a PyArrow Table + table = pa.table([pa.array(np.random.rand(100)) for i in range(3)], names=["a", "b", "c"]) - # Create a Spark DataFrame - df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()}) + # Create a Spark DataFrame from the PyArrow Table + df = spark.createDataFrame(table) # Convert the Spark DataFrame to a PyArrow Table - table = df.select("*").toArrow() + result_table = df.select("*").toArrow() - print(table.schema) - # 0: double not null - # 1: double not null - # 2: double not null + print(result_table.schema) + # a: double + # b: double + # c: double def dataframe_with_arrow_example(spark: SparkSession) -> None: @@ -319,7 +322,7 @@ def arrow_slen(s): # type: ignore[no-untyped-def] .getOrCreate() print("Running Arrow conversion example: DataFrame to Table") - dataframe_to_arrow_table_example(spark) + dataframe_to_from_arrow_table_example(spark) print("Running Pandas to/from conversion example") dataframe_with_arrow_example(spark) print("Running pandas_udf example: Series to Frame") diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst index 0a527d832e211..fde40140110f9 100644 --- a/python/docs/source/user_guide/sql/arrow_pandas.rst +++ b/python/docs/source/user_guide/sql/arrow_pandas.rst @@ -39,19 +39,21 @@ is installed and available on all cluster nodes. You can install it using pip or conda from the conda-forge channel. See PyArrow `installation `_ for details. -Conversion to Arrow Table -------------------------- +Conversion to/from Arrow Table +------------------------------ -You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table. +From Spark 4.0, you can create a Spark DataFrame from a PyArrow Table with +:meth:`SparkSession.createDataFrame`, and you can convert a Spark DataFrame to a PyArrow Table +with :meth:`DataFrame.toArrow`. .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 37-49 + :lines: 37-52 :dedent: 4 Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to -the driver program and should be done on a small subset of the data. Not all Spark data types are -currently supported and an error can be raised if a column has an unsupported type. +the driver program and should be done on a small subset of the data. Not all Spark and Arrow data +types are currently supported and an error can be raised if a column has an unsupported type. Enabling for Conversion to/from Pandas -------------------------------------- @@ -67,7 +69,7 @@ This can be controlled by ``spark.sql.execution.arrow.pyspark.fallback.enabled`` .. 
literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 53-68 + :lines: 56-71 :dedent: 4 Using the above optimizations with Arrow will produce the same results as when Arrow is not @@ -104,7 +106,7 @@ specify the type hints of ``pandas.Series`` and ``pandas.DataFrame`` as below: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 72-96 + :lines: 75-99 :dedent: 4 In the following sections, it describes the combinations of the supported type hints. For simplicity, @@ -127,7 +129,7 @@ The following example shows how to create this Pandas UDF that computes the prod .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 100-130 + :lines: 103-133 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -166,7 +168,7 @@ The following example shows how to create this Pandas UDF: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 134-156 + :lines: 137-159 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -188,7 +190,7 @@ The following example shows how to create this Pandas UDF: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 160-183 + :lines: 163-186 :dedent: 4 For detailed usage, please see :func:`pandas_udf`. @@ -219,7 +221,7 @@ and window operations: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 187-228 + :lines: 190-231 :dedent: 4 .. currentmodule:: pyspark.sql.functions @@ -284,7 +286,7 @@ in the group. .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 232-250 + :lines: 235-253 :dedent: 4 For detailed usage, please see please see :meth:`GroupedData.applyInPandas` @@ -302,7 +304,7 @@ The following example shows how to use :meth:`DataFrame.mapInPandas`: .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 254-265 + :lines: 257-268 :dedent: 4 For detailed usage, please see :meth:`DataFrame.mapInPandas`. @@ -341,7 +343,7 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 269-291 + :lines: 272-294 :dedent: 4 @@ -363,7 +365,7 @@ Here's an example that demonstrates the usage of both a default, pickled Python .. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py :language: python - :lines: 295-313 + :lines: 298-316 :dedent: 4 Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF @@ -414,11 +416,15 @@ and each column will be converted to the Spark session time zone then localized zone, which removes the time zone and displays values as local time. This will occur when calling :meth:`DataFrame.toPandas()` or ``pandas_udf`` with timestamp columns. -When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This -occurs when calling :meth:`SparkSession.createDataFrame` with a Pandas DataFrame or when returning a timestamp from a -``pandas_udf``. These conversions are done automatically to ensure Spark will have data in the -expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond -values will be truncated. 
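As a quick illustration of the to/from Arrow conversion documented in the "Conversion to/from Arrow Table" section above, here is a minimal sketch of the round trip (assuming Spark 4.0+ with PyArrow installed and an active ``SparkSession`` named ``spark``):

    import pyarrow as pa

    # A small PyArrow Table with three double columns, similar to the example in arrow.py
    table = pa.table({"a": [1.0, 2.0], "b": [3.0, 4.0], "c": [5.0, 6.0]})

    # PyArrow Table -> Spark DataFrame: the Arrow schema is mapped to Spark types
    df = spark.createDataFrame(table)

    # Spark DataFrame -> PyArrow Table: this collects every row to the driver,
    # so only do it on small results
    result_table = df.toArrow()
    print(result_table.schema)
    # a: double
    # b: double
    # c: double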
+When timestamp data is transferred from Spark to a PyArrow Table, it will remain in microsecond +resolution with the UTC time zone. This occurs when calling :meth:`DataFrame.toArrow()` with +timestamp columns. + +When timestamp data is transferred from Pandas or PyArrow to Spark, it will be converted to UTC +microseconds. This occurs when calling :meth:`SparkSession.createDataFrame` with a Pandas DataFrame +or PyArrow Table, or when returning a timestamp from a ``pandas_udf``. These conversions are done +automatically to ensure Spark will have data in the expected format, so it is not necessary to do +any of these conversions yourself. Any nanosecond values will be truncated. Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is different from a Pandas timestamp. It is recommended to use Pandas time series functionality when diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 354cf60c20144..d827fe3d28ec0 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -82,7 +82,7 @@ UnresolvedStar, ) from pyspark.sql.connect.functions import builtin as F -from pyspark.sql.pandas.types import from_arrow_schema +from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] @@ -1770,8 +1770,9 @@ def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: return (table, schema) def toArrow(self) -> "pa.Table": + schema = to_arrow_schema(self.schema, error_on_duplicated_field_names_in_struct=True) table, _ = self._to_table() - return table + return table.cast(schema) def toPandas(self) -> "PandasDataFrameLike": query = self._plan.to_proto(self._session.client) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index f99d298ea1170..07a38b8e80042 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -15,6 +15,7 @@ # limitations under the License. 
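The timestamp behaviour described in the documentation changes above (microsecond resolution, UTC time zone, nanosecond truncation) can be sketched as follows; this mirrors the truncation test added later in this change and assumes an active ``SparkSession`` named ``spark``:

    import pyarrow as pa

    # Nanosecond, UTC-aware timestamps on the Arrow side
    t_in = pa.table(
        {"ts": pa.array([1234567890123456789], type=pa.timestamp("ns", tz="UTC"))}
    )

    # PyArrow -> Spark: values are truncated to microseconds
    df = spark.createDataFrame(t_in)

    # Spark -> PyArrow: timestamps come back as timestamp("us", tz="UTC").
    # Timezone-naive input columns that map to Spark TimestampType are instead
    # interpreted in spark.sql.session.timeZone before the conversion.
    t_out = df.toArrow()
    print(t_out["ts"].type)  # timestamp[us, tz=UTC]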
# from pyspark.sql.connect.utils import check_dependencies +from pyspark.sql.utils import is_timestamp_ntz_preferred check_dependencies(__name__) @@ -73,7 +74,9 @@ to_arrow_schema, to_arrow_type, _deduplicate_field_names, + from_arrow_schema, from_arrow_type, + _check_arrow_table_timestamps_localize, ) from pyspark.sql.profiler import Profile from pyspark.sql.session import classproperty, SparkSession as PySparkSession @@ -413,7 +416,7 @@ def _inferSchemaFromList( def createDataFrame( self, - data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]], + data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]], schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None, samplingRatio: Optional[float] = None, verifySchema: Optional[bool] = None, @@ -476,6 +479,7 @@ def createDataFrame( ) _table: Optional[pa.Table] = None + timezone: Optional[str] = None if isinstance(data, pd.DataFrame): # Logic was borrowed from `_create_from_pandas_with_arrow` in @@ -561,6 +565,28 @@ def createDataFrame( cast(StructType, _deduplicate_field_names(schema)).names ).cast(arrow_schema) + elif isinstance(data, pa.Table): + prefer_timestamp_ntz = is_timestamp_ntz_preferred() + + (timezone,) = self._client.get_configs("spark.sql.session.timeZone") + + # If no schema supplied by user then get the names of columns only + if schema is None: + _cols = data.column_names + if isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns): + assert isinstance(_cols, list) + _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))]) + _num_cols = len(_cols) + + if not isinstance(schema, StructType): + schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz) + + _table = ( + _check_arrow_table_timestamps_localize(data, schema, True, timezone) + .cast(to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True)) + .rename_columns(schema.names) + ) + elif isinstance(data, np.ndarray): if _cols is None: if data.ndim == 1 or data.shape[1] == 1: diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index cbb0299e2195d..3fe47615b8761 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: from py4j.java_gateway import JavaObject + import pyarrow as pa from pyspark.core.rdd import RDD from pyspark.core.context import SparkContext from pyspark.sql._typing import ( @@ -343,14 +344,14 @@ def createDataFrame( @overload def createDataFrame( - self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ... + self, data: Union["PandasDataFrameLike", "pa.Table"], samplingRatio: Optional[float] = ... ) -> DataFrame: ... @overload def createDataFrame( self, - data: "PandasDataFrameLike", + data: Union["PandasDataFrameLike", "pa.Table"], schema: Union[StructType, str], verifySchema: bool = ..., ) -> DataFrame: @@ -358,13 +359,14 @@ def createDataFrame( def createDataFrame( # type: ignore[misc] self, - data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike"], + data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "pa.Table"], schema: Optional[Union[AtomicType, StructType, str]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, ) -> DataFrame: """ - Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. + Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`, or + a :class:`pyarrow.Table`. 
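A short sketch of how the ``schema`` argument interacts with a PyArrow Table input, following the branches added above and the behaviour exercised by the new tests (assumes an active ``SparkSession`` named ``spark``):

    import pyarrow as pa

    table = pa.table({"id": [1, 2], "value": ["x", "y"]})

    # No schema: column names and types are taken from the Arrow schema
    df1 = spark.createDataFrame(table)

    # Schema as a list of names: columns are renamed; types still come from Arrow
    df2 = spark.createDataFrame(table, schema=["a", "b"])

    # Schema as a DDL string or StructType: the Arrow data is cast to the Spark types
    df3 = spark.createDataFrame(table, schema="id long, value string")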
When ``schema`` is a list of column names, the type of each column will be inferred from ``data``. @@ -393,12 +395,15 @@ def createDataFrame( # type: ignore[misc] .. versionchanged:: 2.1.0 Added verifySchema. + .. versionchanged:: 4.0.0 + Added support for :class:`pyarrow.Table`. + Parameters ---------- data : :class:`RDD` or iterable an RDD of any kind of SQL data representation (:class:`Row`, - :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or - :class:`pandas.DataFrame`. + :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, + :class:`pandas.DataFrame`, or :class:`pyarrow.Table`. schema : :class:`pyspark.sql.types.DataType`, str or list, optional a :class:`pyspark.sql.types.DataType` or a datatype string or a list of column names, default is None. The data type string format equals to @@ -452,6 +457,12 @@ def createDataFrame( # type: ignore[misc] >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP [Row(0=1, 1=2)] + >>> sqlContext.createDataFrame(df.toArrow()).collect() # doctest: +SKIP + [Row(name='Alice', age=1)] + >>> table = pyarrow.table({'0': [1], '1': [2]}) # doctest: +SKIP + >>> sqlContext.createDataFrame(table).collect() # doctest: +SKIP + [Row(0=1, 1=2)] + >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect() [Row(a='Alice', b=1)] >>> rdd = rdd.map(lambda row: row[1]) diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 344608317beb7..9da15caac8025 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -56,8 +56,8 @@ class PandasConversionMixin: """ - Mix-in for the conversion from Spark to pandas. Currently, only :class:`DataFrame` - can use this class. + Mix-in for the conversion from Spark to pandas and PyArrow. Currently, only + :class:`DataFrame` can use this class. """ def toPandas(self) -> "PandasDataFrameLike": @@ -236,7 +236,7 @@ def toArrow(self) -> "pa.Table": from pyspark.sql.pandas.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - to_arrow_schema(self.schema) + schema = to_arrow_schema(self.schema, error_on_duplicated_field_names_in_struct=True) import pyarrow as pa @@ -244,7 +244,7 @@ def toArrow(self) -> "pa.Table": batches = self._collect_as_arrow( split_batches=self_destruct, empty_list_if_zero_records=False ) - table = pa.Table.from_batches(batches) + table = pa.Table.from_batches(batches).cast(schema) # Ensure only the table has a reference to the batches, so that # self_destruct (if enabled) is effective del batches @@ -320,6 +320,7 @@ def _collect_as_arrow( return [batches[i] for i in batch_order] else: from pyspark.sql.pandas.types import to_arrow_schema + import pyarrow as pa schema = to_arrow_schema(self.schema) empty_arrays = [pa.array([], type=field.type) for field in schema] @@ -328,8 +329,8 @@ def _collect_as_arrow( class SparkConversionMixin: """ - Min-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession` - can use this class. + Min-in for the conversion from pandas and PyArrow to Spark. Currently, only + :class:`SparkSession` can use this class. """ _jsparkSession: "JavaObject" @@ -340,6 +341,12 @@ def createDataFrame( ) -> "DataFrame": ... + @overload + def createDataFrame( + self, data: "pa.Table", samplingRatio: Optional[float] = ... + ) -> "DataFrame": + ... + @overload def createDataFrame( self, @@ -349,9 +356,18 @@ def createDataFrame( ) -> "DataFrame": ... 
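A sketch of the behaviour behind the new ``error_on_duplicated_field_names_in_struct`` flag used in ``toArrow`` above, mirroring the tests added later in this change (assumes an active ``SparkSession`` named ``spark``): duplicate top-level column names round-trip, but duplicate field names inside a struct are rejected.

    import pyarrow as pa

    # Duplicate top-level column names are allowed and survive the round trip
    flat = pa.table(
        [[1, 2], [3, 4]],
        schema=pa.schema([pa.field("a", pa.int64()), pa.field("a", pa.int64())]),
    )
    spark.createDataFrame(flat).toArrow()  # ok

    # Duplicate field names inside a struct raise UnsupportedOperationException
    # with error class DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT
    nested = pa.table(
        {
            "s": pa.StructArray.from_arrays(
                [pa.array(["a"]), pa.array([1])], names=["x", "x"]
            )
        }
    )
    spark.createDataFrame(nested)  # raises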
+ @overload + def createDataFrame( + self, + data: "pa.Table", + schema: Union[StructType, str], + verifySchema: bool = ..., + ) -> "DataFrame": + ... + def createDataFrame( # type: ignore[misc] self, - data: "PandasDataFrameLike", + data: Union["PandasDataFrameLike", "pa.Table"], schema: Optional[Union[StructType, List[str]]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, @@ -360,12 +376,29 @@ def createDataFrame( # type: ignore[misc] assert isinstance(self, SparkSession) + timezone = self._jconf.sessionLocalTimeZone() + + if type(data).__name__ == "Table": + # `data` is a PyArrow Table + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version + + require_minimum_pyarrow_version() + + import pyarrow as pa + + assert isinstance(data, pa.Table) + + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = data.schema.names + + return self._create_from_arrow_table(data, schema, timezone) + + # `data` is a PandasDataFrameLike object from pyspark.sql.pandas.utils import require_minimum_pandas_version require_minimum_pandas_version() - timezone = self._jconf.sessionLocalTimeZone() - # If no schema supplied by user then get the names of columns only if schema is None: schema = [str(x) if not isinstance(x, str) else x for x in data.columns] @@ -711,6 +744,75 @@ def create_iter_server(): df._schema = schema return df + def _create_from_arrow_table( + self, table: "pa.Table", schema: Union[StructType, List[str]], timezone: str + ) -> "DataFrame": + """ + Create a DataFrame from a given pyarrow.Table by slicing it into partitions then + sending to the JVM to parallelize. + """ + from pyspark.sql import SparkSession + from pyspark.sql.dataframe import DataFrame + + assert isinstance(self, SparkSession) + + from pyspark.sql.pandas.serializers import ArrowStreamSerializer + from pyspark.sql.pandas.types import ( + from_arrow_type, + from_arrow_schema, + to_arrow_schema, + _check_arrow_table_timestamps_localize, + ) + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version + + require_minimum_pyarrow_version() + + prefer_timestamp_ntz = is_timestamp_ntz_preferred() + + # Create the Spark schema from list of names passed in with Arrow types + if isinstance(schema, (list, tuple)): + table = table.rename_columns(schema) + arrow_schema = table.schema + struct = StructType() + for name, field in zip(schema, arrow_schema): + struct.add( + name, + from_arrow_type(field.type, prefer_timestamp_ntz), + nullable=field.nullable, + ) + schema = struct + + if not isinstance(schema, StructType): + schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=prefer_timestamp_ntz) + + table = _check_arrow_table_timestamps_localize(table, schema, True, timezone).cast( + to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True) + ) + + # Chunk the Arrow Table into RecordBatches + chunk_size = self._jconf.arrowMaxRecordsPerBatch() + arrow_data = table.to_batches(max_chunksize=chunk_size) + + jsparkSession = self._jsparkSession + + ser = ArrowStreamSerializer() + + @no_type_check + def reader_func(temp_filename): + return self._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename) + + @no_type_check + def create_iter_server(): + return self._jvm.ArrowIteratorServer() + + # Create Spark DataFrame from Arrow stream file, using one batch per partition + jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server) + assert self._jvm is not None + jdf = 
self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession) + df = DataFrame(jdf, self) + df._schema = schema + return df + def _test() -> None: import doctest diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 559512bd00c1c..30675d5550465 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -60,8 +60,32 @@ from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike -def to_arrow_type(dt: DataType) -> "pa.DataType": - """Convert Spark data type to pyarrow type""" +def to_arrow_type( + dt: DataType, + error_on_duplicated_field_names_in_struct: bool = False, + timestamp_utc: bool = True, +) -> "pa.DataType": + """ + Convert Spark data type to PyArrow type + + Parameters + ---------- + dt : :class:`DataType` + The Spark data type. + error_on_duplicated_field_names_in_struct: bool, default False + Whether to raise an exception when there are duplicated field names in a + :class:`pyspark.sql.types.StructType`. (default ``False``) + timestamp_utc : bool, default True + If ``True`` (the default), :class:`TimestampType` is converted to a timezone-aware + :class:`pyarrow.TimestampType` with UTC as the timezone. If ``False``, + :class:`TimestampType` is converted to a timezone-naive :class:`pyarrow.TimestampType`. + The JVM expects timezone-aware timestamps to be in UTC. Always keep this set to ``True`` + except in special cases, such as when this function is used in a test. + + Returns + ------- + :class:`pyarrow.DataType` + """ import pyarrow as pa if type(dt) == BooleanType: @@ -86,30 +110,58 @@ def to_arrow_type(dt: DataType) -> "pa.DataType": arrow_type = pa.binary() elif type(dt) == DateType: arrow_type = pa.date32() - elif type(dt) == TimestampType: + elif type(dt) == TimestampType and timestamp_utc: # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp("us", tz="UTC") + elif type(dt) == TimestampType: + arrow_type = pa.timestamp("us", tz=None) elif type(dt) == TimestampNTZType: arrow_type = pa.timestamp("us", tz=None) elif type(dt) == DayTimeIntervalType: arrow_type = pa.duration("us") elif type(dt) == ArrayType: - field = pa.field("element", to_arrow_type(dt.elementType), nullable=dt.containsNull) + field = pa.field( + "element", + to_arrow_type(dt.elementType, error_on_duplicated_field_names_in_struct, timestamp_utc), + nullable=dt.containsNull, + ) arrow_type = pa.list_(field) elif type(dt) == MapType: - key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False) - value_field = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull) + key_field = pa.field( + "key", + to_arrow_type(dt.keyType, error_on_duplicated_field_names_in_struct, timestamp_utc), + nullable=False, + ) + value_field = pa.field( + "value", + to_arrow_type(dt.valueType, error_on_duplicated_field_names_in_struct, timestamp_utc), + nullable=dt.valueContainsNull, + ) arrow_type = pa.map_(key_field, value_field) elif type(dt) == StructType: + field_names = dt.names + if error_on_duplicated_field_names_in_struct and len(set(field_names)) != len(field_names): + raise UnsupportedOperationException( + error_class="DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT", + message_parameters={"field_names": str(field_names)}, + ) fields = [ - pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + pa.field( + field.name, + to_arrow_type( + field.dataType, error_on_duplicated_field_names_in_struct, timestamp_utc + ), + 
nullable=field.nullable, + ) for field in dt ] arrow_type = pa.struct(fields) elif type(dt) == NullType: arrow_type = pa.null() elif isinstance(dt, UserDefinedType): - arrow_type = to_arrow_type(dt.sqlType()) + arrow_type = to_arrow_type( + dt.sqlType(), error_on_duplicated_field_names_in_struct, timestamp_utc + ) elif type(dt) == VariantType: fields = [ pa.field("value", pa.binary(), nullable=False), @@ -124,12 +176,40 @@ def to_arrow_type(dt: DataType) -> "pa.DataType": return arrow_type -def to_arrow_schema(schema: StructType) -> "pa.Schema": - """Convert a schema from Spark to Arrow""" +def to_arrow_schema( + schema: StructType, + error_on_duplicated_field_names_in_struct: bool = False, + timestamp_utc: bool = True, +) -> "pa.Schema": + """ + Convert a schema from Spark to Arrow + + Parameters + ---------- + schema : :class:`StructType` + The Spark schema. + error_on_duplicated_field_names_in_struct: bool, default False + Whether to raise an exception when there are duplicated field names in an inner + :class:`pyspark.sql.types.StructType`. (default ``False``) + timestamp_utc : bool, default True + If ``True`` (the default), :class:`TimestampType` is converted to a timezone-aware + :class:`pyarrow.TimestampType` with UTC as the timezone. If ``False``, + :class:`TimestampType` is converted to a timezone-naive :class:`pyarrow.TimestampType`. + The JVM expects timezone-aware timestamps to be in UTC. Always keep this set to ``True`` + except in special cases, such as when this function is used in a test + + Returns + ------- + :class:`pyarrow.Schema` + """ import pyarrow as pa fields = [ - pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + pa.field( + field.name, + to_arrow_type(field.dataType, error_on_duplicated_field_names_in_struct, timestamp_utc), + nullable=field.nullable, + ) for field in schema ] return pa.schema(fields) @@ -232,6 +312,154 @@ def _get_local_timezone() -> str: return os.environ.get("TZ", "dateutil/:") +def _check_arrow_array_timestamps_localize( + a: Union["pa.Array", "pa.ChunkedArray"], + dt: DataType, + truncate: bool = True, + timezone: Optional[str] = None, +) -> Union["pa.Array", "pa.ChunkedArray"]: + """ + Convert Arrow timestamps to timezone-naive in the specified timezone if the specified Spark + data type is TimestampType, and optionally truncate nanosecond timestamps to microseconds. + + This function works on Arrow Arrays and ChunkedArrays, and it recurses to convert nested + timestamps. + + Parameters + ---------- + a : :class:`pyarrow.Array` or :class:`pyarrow.ChunkedArray` + dt : :class:`DataType` + The Spark data type corresponding to the Arrow Array to be converted. + truncate : bool, default True + Whether to truncate nanosecond timestamps to microseconds. (default ``True``) + timezone : str, optional + The timezone to convert from. If there is a timestamp type, it's required. 
+ + Returns + ------- + :class:`pyarrow.Array` or :class:`pyarrow.ChunkedArray` + """ + import pyarrow.types as types + import pyarrow as pa + import pyarrow.compute as pc + + if isinstance(a, pa.ChunkedArray) and (types.is_nested(a.type) or types.is_dictionary(a.type)): + return pa.chunked_array( + [ + _check_arrow_array_timestamps_localize(chunk, dt, truncate, timezone) + for chunk in a.iterchunks() + ] + ) + + if types.is_timestamp(a.type) and truncate and a.type.unit == "ns": + a = pc.floor_temporal(a, unit="microsecond") + + if types.is_timestamp(a.type) and a.type.tz is None and type(dt) == TimestampType: + assert timezone is not None + + # Only localize timestamps that will become Spark TimestampType columns. + # Do not localize timestamps that will become Spark TimestampNTZType columns. + return pc.assume_timezone(a, timezone) + if types.is_list(a.type): + # Return the ListArray as-is if it contains no nested fields or timestamps + if not types.is_nested(a.type.value_type) and not types.is_timestamp(a.type.value_type): + return a + + at: ArrayType = cast(ArrayType, dt) + return pa.ListArray.from_arrays( + a.offsets, + _check_arrow_array_timestamps_localize(a.values, at.elementType, truncate, timezone), + mask=a.is_null() if a.null_count else None, + ) + if types.is_map(a.type): + # Return the MapArray as-is if it contains no nested fields or timestamps + if ( + not types.is_nested(a.type.key_type) + and not types.is_nested(a.type.item_type) + and not types.is_timestamp(a.type.key_type) + and not types.is_timestamp(a.type.item_type) + ): + return a + + mt: MapType = cast(MapType, dt) + # TODO(SPARK-48302): Do not replace nulls in MapArray with empty lists + return pa.MapArray.from_arrays( + a.offsets, + _check_arrow_array_timestamps_localize(a.keys, mt.keyType, truncate, timezone), + _check_arrow_array_timestamps_localize(a.items, mt.valueType, truncate, timezone), + ) + if types.is_struct(a.type): + # Return the StructArray as-is if it contains no nested fields or timestamps + if all( + [ + not types.is_nested(a.type.field(i).type) + and not types.is_timestamp(a.type.field(i).type) + for i in range(a.type.num_fields) + ] + ): + return a + + st: StructType = cast(StructType, dt) + assert len(a.type) == len(st.fields) + + return pa.StructArray.from_arrays( + [ + _check_arrow_array_timestamps_localize( + a.field(i), st.fields[i].dataType, truncate, timezone + ) + for i in range(len(a.type)) + ], + [a.type[i].name for i in range(len(a.type))], + mask=a.is_null() if a.null_count else None, + ) + if types.is_dictionary(a.type): + return pa.DictionaryArray.from_arrays( + a.indices, + _check_arrow_array_timestamps_localize(a.dictionary, dt, truncate, timezone), + ) + return a + + +def _check_arrow_table_timestamps_localize( + table: "pa.Table", schema: StructType, truncate: bool = True, timezone: Optional[str] = None +) -> "pa.Table": + """ + Convert timestamps in a PyArrow Table to timezone-naive in the specified timezone if the + corresponding Spark data type is TimestampType in the specified Spark schema is TimestampType, + and optionally truncate nanosecond timestamps to microseconds. + + Parameters + ---------- + table : :class:`pyarrow.Table` + schema : :class:`StructType` + The Spark schema corresponding to the schema of the Arrow Table. + truncate : bool, default True + Whether to truncate nanosecond timestamps to microseconds. (default ``True``) + timezone : str, optional + The timezone to convert from. If there is a timestamp type, it's required. 
+ + Returns + ------- + :class:`pyarrow.Table` + """ + import pyarrow.types as types + import pyarrow as pa + + # Return the table as-is if it contains no nested fields or timestamps + if all([not types.is_nested(at) and not types.is_timestamp(at) for at in table.schema.types]): + return table + + assert len(table.schema) == len(schema.fields) + + return pa.Table.from_arrays( + [ + _check_arrow_array_timestamps_localize(a, f.dataType, truncate, timezone) + for a, f in zip(table.columns, schema.fields) + ], + schema=table.schema, + ) + + def _check_series_localize_timestamps(s: "PandasSeriesLike", timezone: str) -> "PandasSeriesLike": """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9077ee8874444..d6fb4b60d90a9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -67,6 +67,7 @@ if TYPE_CHECKING: from py4j.java_gateway import JavaObject + import pyarrow as pa from pyspark.core.context import SparkContext from pyspark.core.rdd import RDD from pyspark.sql._typing import AtomicValue, RowLike, OptionalPrimitiveType @@ -1256,7 +1257,7 @@ def _getActiveSessionOrCreate(**static_conf: Any) -> "SparkSession": spark = builder.getOrCreate() return spark - @overload + @overload # type: ignore[override] def createDataFrame( self, data: Iterable["RowLike"], @@ -1318,6 +1319,10 @@ def createDataFrame( ) -> DataFrame: ... + @overload + def createDataFrame(self, data: "pa.Table", samplingRatio: Optional[float] = ...) -> DataFrame: + ... + @overload def createDataFrame( self, @@ -1327,28 +1332,40 @@ def createDataFrame( ) -> DataFrame: ... + @overload + def createDataFrame( + self, + data: "pa.Table", + schema: Union[StructType, str], + verifySchema: bool = ..., + ) -> DataFrame: + ... + def createDataFrame( # type: ignore[misc] self, - data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "ArrayLike"], + data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "ArrayLike", "pa.Table"], schema: Optional[Union[AtomicType, StructType, str]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, ) -> DataFrame: """ - Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame` - or a :class:`numpy.ndarray`. + Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`, + a :class:`numpy.ndarray`, or a :class:`pyarrow.Table`. .. versionadded:: 2.0.0 .. versionchanged:: 3.4.0 Supports Spark Connect. + .. versionchanged:: 4.0.0 + Supports :class:`pyarrow.Table`. + Parameters ---------- data : :class:`RDD` or iterable an RDD of any kind of SQL data representation (:class:`Row`, :class:`tuple`, ``int``, ``boolean``, ``dict``, etc.), or :class:`list`, - :class:`pandas.DataFrame` or :class:`numpy.ndarray`. + :class:`pandas.DataFrame`, :class:`numpy.ndarray`, or :class:`pyarrow.Table`. schema : :class:`pyspark.sql.types.DataType`, str or list, optional a :class:`pyspark.sql.types.DataType` or a datatype string or a list of column names, default is None. The data type string format equals to @@ -1374,10 +1391,10 @@ def createDataFrame( # type: ignore[misc] :class:`RDD`. verifySchema : bool, optional verify data types of every row against schema. Enabled by default. - When the input is :class:`pandas.DataFrame` and - `spark.sql.execution.arrow.pyspark.enabled` is enabled, this option is not - effective. It follows Arrow type coercion. 
This option is not supported with - Spark Connect. + When the input is :class:`pyarrow.Table` or when the input class is + :class:`pandas.DataFrame` and `spark.sql.execution.arrow.pyspark.enabled` is enabled, + this option is not effective. It follows Arrow type coercion. This option is not + supported with Spark Connect. .. versionadded:: 2.1.0 @@ -1477,6 +1494,22 @@ def createDataFrame( # type: ignore[misc] +---+---+ | 1| 2| +---+---+ + + Create a DataFrame from a PyArrow Table. + + >>> spark.createDataFrame(df.toArrow()).show() # doctest: +SKIP + +-----+---+ + | name|age| + +-----+---+ + |Alice| 1| + +-----+---+ + >>> table = pyarrow.table({'0': [1], '1': [2]}) # doctest: +SKIP + >>> spark.createDataFrame(table).collect() # doctest: +SKIP + +---+---+ + | 0| 1| + +---+---+ + | 1| 2| + +---+---+ """ SparkSession._activeSession = self assert self._jvm is not None @@ -1507,6 +1540,13 @@ def createDataFrame( # type: ignore[misc] except Exception: has_numpy = False + try: + import pyarrow as pa + + has_pyarrow = True + except Exception: + has_pyarrow = False + if has_numpy and isinstance(data, np.ndarray): # `data` of numpy.ndarray type will be converted to a pandas DataFrame, # so pandas is required. @@ -1540,6 +1580,11 @@ def createDataFrame( # type: ignore[misc] return super(SparkSession, self).createDataFrame( # type: ignore[call-overload] data, schema, samplingRatio, verifySchema ) + if has_pyarrow and isinstance(data, pa.Table): + # Create a DataFrame from PyArrow Table. + return super(SparkSession, self).createDataFrame( # type: ignore[call-overload] + data, schema, samplingRatio, verifySchema + ) return self._create_dataframe( data, schema, samplingRatio, verifySchema # type: ignore[arg-type] ) diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py index fa0bc25885810..885b3001b1db1 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py @@ -31,8 +31,11 @@ def test_createDataFrame_fallback_disabled(self): def test_createDataFrame_fallback_enabled(self): super().test_createDataFrame_fallback_enabled() - def test_createDataFrame_with_map_type(self): - self.check_createDataFrame_with_map_type(True) + def test_createDataFrame_pandas_with_map_type(self): + self.check_createDataFrame_pandas_with_map_type(True) + + def test_createDataFrame_pandas_with_struct_type(self): + self.check_createDataFrame_pandas_with_struct_type(True) def test_createDataFrame_with_ndarray(self): self.check_createDataFrame_with_ndarray(True) @@ -69,6 +72,9 @@ def test_create_data_frame_to_pandas_timestamp_ntz(self): def test_create_data_frame_to_pandas_day_time_internal(self): self.check_create_data_frame_to_pandas_day_time_internal(True) + def test_createDataFrame_pandas_respect_session_timezone(self): + self.check_createDataFrame_pandas_respect_session_timezone(True) + def test_toPandas_respect_session_timezone(self): self.check_toPandas_respect_session_timezone(True) @@ -89,11 +95,11 @@ def test_toPandas_with_map_type(self): def test_toPandas_with_map_type_nulls(self): self.check_toPandas_with_map_type_nulls(True) - def test_createDataFrame_with_array_type(self): - self.check_createDataFrame_with_array_type(True) + def test_createDataFrame_pandas_with_array_type(self): + self.check_createDataFrame_pandas_with_array_type(True) - def test_createDataFrame_with_int_col_names(self): - self.check_createDataFrame_with_int_col_names(True) + def 
test_createDataFrame_pandas_with_int_col_names(self): + self.check_createDataFrame_pandas_with_int_col_names(True) def test_timestamp_nat(self): self.check_timestamp_nat(True) @@ -104,14 +110,17 @@ def test_toPandas_error(self): def test_toPandas_duplicate_field_names(self): self.check_toPandas_duplicate_field_names(True) - def test_createDataFrame_duplicate_field_names(self): - self.check_createDataFrame_duplicate_field_names(True) + def test_createDataFrame_pandas_duplicate_field_names(self): + self.check_createDataFrame_pandas_duplicate_field_names(True) + + def test_toPandas_empty_rows(self): + self.check_toPandas_empty_rows(True) def test_toPandas_empty_columns(self): self.check_toPandas_empty_columns(True) - def test_createDataFrame_nested_timestamp(self): - self.check_createDataFrame_nested_timestamp(True) + def test_createDataFrame_pandas_nested_timestamp(self): + self.check_createDataFrame_pandas_nested_timestamp(True) def test_toPandas_nested_timestamp(self): self.check_toPandas_nested_timestamp(True) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 71d3c46e5ee1e..a2221f983694e 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -160,6 +160,45 @@ def setUpClass(cls): ] cls.data = [tuple(list(d) + [None]) for d in cls.data_wo_null] + cls.schema_nested_timestamp = ( + StructType() + .add("ts", TimestampType()) + .add("ts_ntz", TimestampNTZType()) + .add( + "struct", StructType().add("ts", TimestampType()).add("ts_ntz", TimestampNTZType()) + ) + .add("array", ArrayType(TimestampType())) + .add("array_ntz", ArrayType(TimestampNTZType())) + .add("map", MapType(StringType(), TimestampType())) + .add("map_ntz", MapType(StringType(), TimestampNTZType())) + ) + cls.data_nested_timestamp = [ + Row( + datetime(2023, 1, 1, 0, 0, 0), + datetime(2023, 1, 1, 0, 0, 0), + Row( + datetime(2023, 1, 1, 0, 0, 0), + datetime(2023, 1, 1, 0, 0, 0), + ), + [datetime(2023, 1, 1, 0, 0, 0)], + [datetime(2023, 1, 1, 0, 0, 0)], + dict(ts=datetime(2023, 1, 1, 0, 0, 0)), + dict(ts_ntz=datetime(2023, 1, 1, 0, 0, 0)), + ) + ] + cls.data_nested_timestamp_expected_ny = Row( + ts=datetime(2022, 12, 31, 21, 0, 0), + ts_ntz=datetime(2023, 1, 1, 0, 0, 0), + struct=Row( + ts=datetime(2022, 12, 31, 21, 0, 0), + ts_ntz=datetime(2023, 1, 1, 0, 0, 0), + ), + array=[datetime(2022, 12, 31, 21, 0, 0)], + array_ntz=[datetime(2023, 1, 1, 0, 0, 0)], + map=dict(ts=datetime(2022, 12, 31, 21, 0, 0)), + map_ntz=dict(ts_ntz=datetime(2023, 1, 1, 0, 0, 0)), + ) + @classmethod def tearDownClass(cls): del os.environ["TZ"] @@ -181,7 +220,6 @@ def create_pandas_data_frame(self): def create_arrow_table(self): import pyarrow as pa - import pyarrow.compute as pc data_dict = {} for j, name in enumerate(self.schema.names): @@ -199,13 +237,6 @@ def create_arrow_table(self): pa.field("6_decimal_t", pa.decimal128(38, 18)), ) t = t.cast(new_schema) - # convert timestamp to local timezone - timezone = self.spark.conf.get("spark.sql.session.timeZone") - t = t.set_column( - t.schema.get_field_index("8_timestamp_t"), - "8_timestamp_t", - pc.assume_timezone(t["8_timestamp_t"], timezone), - ) return t @property @@ -315,6 +346,17 @@ def check_create_data_frame_to_pandas_timestamp_ntz(self, arrow_enabled): pdf = df.toPandas() assert_frame_equal(origin, pdf) + def test_create_data_frame_to_arrow_timestamp_ntz(self): + with self.sql_conf({"spark.sql.session.timeZone": "America/Los_Angeles"}): + origin = pa.table({"a": [datetime.datetime(2012, 2, 2, 2, 2, 2)]}) + 
df = self.spark.createDataFrame( + origin, schema=StructType([StructField("a", TimestampNTZType(), True)]) + ) + df.selectExpr("assert_true('2012-02-02 02:02:02' == CAST(a AS STRING))").collect() + + t = df.toArrow() + self.assertTrue(origin.equals(t)) + def test_create_data_frame_to_pandas_day_time_internal(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): @@ -332,6 +374,16 @@ def check_create_data_frame_to_pandas_day_time_internal(self, arrow_enabled): pdf = df.toPandas() assert_frame_equal(origin, pdf) + def test_create_data_frame_to_arrow_day_time_internal(self): + origin = pa.table({"a": [datetime.timedelta(microseconds=123)]}) + df = self.spark.createDataFrame(origin) + df.select( + assert_true(lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.a.cast("string")) + ).collect() + + t = df.toArrow() + self.assertTrue(origin.equals(t)) + def test_toPandas_respect_session_timezone(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): @@ -362,6 +414,21 @@ def check_toPandas_respect_session_timezone(self, arrow_enabled): ) assert_frame_equal(pdf_ny, pdf_la_corrected) + def test_toArrow_keep_utc_timezone(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + + timezone = "America/Los_Angeles" + with self.sql_conf({"spark.sql.session.timeZone": timezone}): + t_la = df.toArrow() + + timezone = "America/New_York" + with self.sql_conf({"spark.sql.session.timeZone": timezone}): + t_ny = df.toArrow() + + self.assertTrue(t_ny.equals(t_la)) + self.assertEqual(t_la["8_timestamp_t"].type.tz, "UTC") + self.assertEqual(t_ny["8_timestamp_t"].type.tz, "UTC") + def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -369,9 +436,28 @@ def test_pandas_round_trip(self): assert_frame_equal(pdf_arrow, pdf) def test_arrow_round_trip(self): + import pyarrow.compute as pc + t_in = self.create_arrow_table() + + # Convert timezone-naive local timestamp column in input table to UTC + # to enable comparison to UTC timestamp column in output table + timezone = self.spark.conf.get("spark.sql.session.timeZone") + t_in = t_in.set_column( + t_in.schema.get_field_index("8_timestamp_t"), + "8_timestamp_t", + pc.assume_timezone(t_in["8_timestamp_t"], timezone), + ) + t_in = t_in.cast( + t_in.schema.set( + t_in.schema.get_field_index("8_timestamp_t"), + pa.field("8_timestamp_t", pa.timestamp("us", tz="UTC")), + ) + ) + df = self.spark.createDataFrame(self.data, schema=self.schema) t_out = df.toArrow() + self.assertTrue(t_out.equals(t_in)) def test_pandas_self_destruct(self): @@ -437,6 +523,13 @@ def raise_exception(): with self.assertRaisesRegex(Exception, "My error"): df.toPandas() + def test_createDataFrame_arrow_pandas(self): + table = self.create_arrow_table() + pdf = self.create_pandas_data_frame() + df_arrow = self.spark.createDataFrame(table) + df_pandas = self.spark.createDataFrame(pdf) + self.assertEqual(df_arrow.collect(), df_pandas.collect()) + def _createDataFrame_toggle(self, data, schema=None): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): df_no_arrow = self.spark.createDataFrame(data, schema=schema) @@ -450,12 +543,12 @@ def test_createDataFrame_toggle(self): df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema) self.assertEqual(df_no_arrow.collect(), df_arrow.collect()) - def test_createDataFrame_respect_session_timezone(self): + def 
test_createDataFrame_pandas_respect_session_timezone(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): - self.check_createDataFrame_respect_session_timezone(arrow_enabled) + self.check_createDataFrame_pandas_respect_session_timezone(arrow_enabled) - def check_createDataFrame_respect_session_timezone(self, arrow_enabled): + def check_createDataFrame_pandas_respect_session_timezone(self, arrow_enabled): from datetime import timedelta pdf = self.create_pandas_data_frame() @@ -485,18 +578,46 @@ def check_createDataFrame_respect_session_timezone(self, arrow_enabled): ] self.assertEqual(result_ny, result_la_corrected) - def test_createDataFrame_with_schema(self): + def test_createDataFrame_arrow_respect_session_timezone(self): + from datetime import timedelta + + t = self.create_arrow_table() + timezone = "America/Los_Angeles" + with self.sql_conf({"spark.sql.session.timeZone": timezone}): + df_la = self.spark.createDataFrame(t, schema=self.schema) + result_la = df_la.collect() + + timezone = "America/New_York" + with self.sql_conf({"spark.sql.session.timeZone": timezone}): + df_ny = self.spark.createDataFrame(t, schema=self.schema) + result_ny = df_ny.collect() + + self.assertNotEqual(result_ny, result_la) + + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + result_la_corrected = [ + Row( + **{ + k: v - timedelta(hours=3) if k == "8_timestamp_t" else v + for k, v in row.asDict().items() + } + ) + for row in result_la + ] + self.assertEqual(result_ny, result_la_corrected) + + def test_createDataFrame_pandas_with_schema(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEqual(self.schema, df.schema) pdf_arrow = df.toPandas() assert_frame_equal(pdf_arrow, pdf) - def test_createDataFrame_with_incorrect_schema(self): + def test_createDataFrame_pandas_with_incorrect_schema(self): with self.quiet(): - self.check_createDataFrame_with_incorrect_schema() + self.check_createDataFrame_pandas_with_incorrect_schema() - def check_createDataFrame_with_incorrect_schema(self): + def check_createDataFrame_pandas_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() fields = list(self.schema) fields[5], fields[6] = fields[6], fields[5] # swap decimal with date @@ -520,7 +641,15 @@ def check_createDataFrame_with_incorrect_schema(self): self.assertEqual(len(exception.args), 1) self.assertRegex(exception.args[0], "[D|d]ecimal.*got.*date") - def test_createDataFrame_with_names(self): + def test_createDataFrame_arrow_with_incorrect_schema(self): + t = self.create_arrow_table() + fields = list(self.schema) + fields[5], fields[6] = fields[6], fields[5] # swap decimal with date + wrong_schema = StructType(fields) + with self.assertRaises(Exception): + self.spark.createDataFrame(t, schema=wrong_schema) + + def test_createDataFrame_pandas_with_names(self): pdf = self.create_pandas_data_frame() new_names = list(map(str, range(len(self.schema.fieldNames())))) # Test that schema as a list of column names gets applied @@ -530,7 +659,17 @@ def test_createDataFrame_with_names(self): df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) self.assertEqual(df.schema.fieldNames(), new_names) - def test_createDataFrame_column_name_encoding(self): + def test_createDataFrame_arrow_with_names(self): + t = self.create_arrow_table() + new_names = list(map(str, range(len(self.schema.fieldNames())))) + # Test that schema as a list of column names gets applied + df = 
self.spark.createDataFrame(t, schema=list(new_names)) + self.assertEqual(df.schema.fieldNames(), new_names) + # Test that schema as tuple of column names gets applied + df = self.spark.createDataFrame(t, schema=tuple(new_names)) + self.assertEqual(df.schema.fieldNames(), new_names) + + def test_createDataFrame_pandas_column_name_encoding(self): pdf = pd.DataFrame({"a": [1]}) columns = self.spark.createDataFrame(pdf).columns self.assertTrue(isinstance(columns[0], str)) @@ -539,6 +678,15 @@ def test_createDataFrame_column_name_encoding(self): self.assertTrue(isinstance(columns[0], str)) self.assertEqual(columns[0], "b") + def test_createDataFrame_arrow_column_name_encoding(self): + t = pa.table({"a": [1]}) + columns = self.spark.createDataFrame(t).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEqual(columns[0], "a") + columns = self.spark.createDataFrame(t, ["b"]).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEqual(columns[0], "b") + def test_createDataFrame_with_single_data_type(self): with self.quiet(): self.check_createDataFrame_with_single_data_type() @@ -566,6 +714,17 @@ def test_createDataFrame_does_not_modify_input(self): self.spark.createDataFrame(pdf, schema=self.schema) self.assertTrue(pdf.equals(pdf_copy)) + def test_createDataFrame_arrow_truncate_timestamp(self): + t_in = pa.Table.from_arrays( + [pa.array([1234567890123456789], type=pa.timestamp("ns", tz="UTC"))], names=["ts"] + ) + df = self.spark.createDataFrame(t_in) + t_out = df.toArrow() + expected = pa.Table.from_arrays( + [pa.array([1234567890123456], type=pa.timestamp("us", tz="UTC"))], names=["ts"] + ) + self.assertTrue(t_out.equals(expected)) + def test_schema_conversion_roundtrip(self): from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema @@ -600,12 +759,12 @@ def check_createDataFrame_with_ndarray(self, arrow_enabled): ): self.spark.createDataFrame(np.array(0)) - def test_createDataFrame_with_array_type(self): + def test_createDataFrame_pandas_with_array_type(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): - self.check_createDataFrame_with_array_type(arrow_enabled) + self.check_createDataFrame_pandas_with_array_type(arrow_enabled) - def check_createDataFrame_with_array_type(self, arrow_enabled): + def check_createDataFrame_pandas_with_array_type(self, arrow_enabled): pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [["x", "y"], ["y", "z"]]}) with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}): df = self.spark.createDataFrame(pdf) @@ -615,6 +774,18 @@ def check_createDataFrame_with_array_type(self, arrow_enabled): for e in range(len(expected[r])): self.assertTrue(expected[r][e] == result[r][e]) + def test_createDataFrame_arrow_with_array_type_nulls(self): + t = pa.table({"a": [[1, 2], None, [3, 4]], "b": [["x", "y"], ["y", "z"], None]}) + df = self.spark.createDataFrame(t) + result = df.collect() + expected = [ + tuple(list(e) if e is not None else None for e in rec) + for rec in t.to_pandas().to_records(index=False) + ] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result[r][e]) + def test_toPandas_with_array_type(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): @@ -632,13 +803,28 @@ def check_toPandas_with_array_type(self, arrow_enabled): for e in range(len(expected[r])): self.assertTrue(expected[r][e] == result[r][e]) - def test_createDataFrame_with_map_type(self): + def 
test_toArrow_with_array_type_nulls(self): + expected = [([1, 2], ["x", "y"]), (None, ["y", "z"]), ([3, 4], None)] + array_schema = StructType( + [StructField("a", ArrayType(IntegerType())), StructField("b", ArrayType(StringType()))] + ) + df = self.spark.createDataFrame(expected, schema=array_schema) + t = df.toArrow() + result = [ + tuple(None if e is None else list(e) for e in rec) + for rec in t.to_pandas().to_records(index=False) + ] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result[r][e]) + + def test_createDataFrame_pandas_with_map_type(self): with self.quiet(): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): - self.check_createDataFrame_with_map_type(arrow_enabled) + self.check_createDataFrame_pandas_with_map_type(arrow_enabled) - def check_createDataFrame_with_map_type(self, arrow_enabled): + def check_createDataFrame_pandas_with_map_type(self, arrow_enabled): map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}] pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data}) @@ -656,12 +842,52 @@ def check_createDataFrame_with_map_type(self, arrow_enabled): i, m = row self.assertEqual(m, map_data[i]) - def test_createDataFrame_with_struct_type(self): + def test_createDataFrame_arrow_with_map_type(self): + map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, {}, {"d": None}] + + t = pa.table( + {"id": [0, 1, 2, 3, 4], "m": map_data}, + schema=pa.schema([("id", pa.int64()), ("m", pa.map_(pa.string(), pa.int64()))]), + ) + for schema in ( + "id long, m map", + StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())), + ): + with self.subTest(schema=schema): + df = self.spark.createDataFrame(t, schema=schema) + + result = df.collect() + + for row in result: + i, m = row + self.assertEqual(m, map_data[i]) + + def test_createDataFrame_arrow_with_map_type_nulls(self): + map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}] + + t = pa.table( + {"id": [0, 1, 2, 3, 4], "m": map_data}, + schema=pa.schema([("id", pa.int64()), ("m", pa.map_(pa.string(), pa.int64()))]), + ) + for schema in ( + "id long, m map", + StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())), + ): + with self.subTest(schema=schema): + df = self.spark.createDataFrame(t, schema=schema) + + result = df.collect() + + for row in result: + i, m = row + self.assertEqual(m, map_data[i]) + + def test_createDataFrame_pandas_with_struct_type(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): - self.check_createDataFrame_with_struct_type(arrow_enabled) + self.check_createDataFrame_pandas_with_struct_type(arrow_enabled) - def check_createDataFrame_with_struct_type(self, arrow_enabled): + def check_createDataFrame_pandas_with_struct_type(self, arrow_enabled): pdf = pd.DataFrame( {"a": [Row(1, "a"), Row(2, "b")], "b": [{"s": 3, "t": "x"}, {"s": 4, "t": "y"}]} ) @@ -682,6 +908,42 @@ def check_createDataFrame_with_struct_type(self, arrow_enabled): expected[r][e] == result[r][e], f"{expected[r][e]} == {result[r][e]}" ) + def test_createDataFrame_pandas_with_struct_type(self): + for arrow_enabled in [True, False]: + with self.subTest(arrow_enabled=arrow_enabled): + self.check_createDataFrame_pandas_with_struct_type(arrow_enabled) + + def test_createDataFrame_arrow_with_struct_type_nulls(self): + t = pa.table( + { + "a": [{"x": 1, "y": "a"}, None, {"x": None, "y": "b"}], + "b": [{"s": 3, "t": None}, {"s": 4, "t": "y"}, None], + }, + ) + for schema in ( + 
"a struct, b struct", + StructType() + .add("a", StructType().add("x", LongType()).add("y", StringType())) + .add("b", StructType().add("s", LongType()).add("t", StringType())), + ): + with self.subTest(schema=schema): + df = self.spark.createDataFrame(t, schema) + result = df.collect() + expected = [ + ( + Row( + a=None if rec[0] is None else (Row(**rec[0])), + b=None if rec[1] is None else Row(**rec[1]), + ) + ) + for rec in t.to_pandas().to_records(index=False) + ] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue( + expected[r][e] == result[r][e], f"{expected[r][e]} == {result[r][e]}" + ) + def test_createDataFrame_with_string_dtype(self): # SPARK-34521: spark.createDataFrame does not support Pandas StringDtype extension type with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}): @@ -725,6 +987,22 @@ def check_toPandas_with_map_type(self, arrow_enabled): pdf = df.toPandas() assert_frame_equal(origin, pdf) + def test_toArrow_with_map_type(self): + origin = pa.table( + {"id": [0, 1, 2, 3], "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]}, + schema=pa.schema( + [pa.field("id", pa.int64()), pa.field("m", pa.map_(pa.string(), pa.int64()), True)] + ), + ) + for schema in [ + "id long, m map", + StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())), + ]: + df = self.spark.createDataFrame(origin, schema=schema) + + t = df.toArrow() + self.assertTrue(origin.equals(t)) + def test_toPandas_with_map_type_nulls(self): with self.quiet(): for arrow_enabled in [True, False]: @@ -747,12 +1025,29 @@ def check_toPandas_with_map_type_nulls(self, arrow_enabled): pdf = df.toPandas() assert_frame_equal(origin, pdf) - def test_createDataFrame_with_int_col_names(self): + def test_toArrow_with_map_type_nulls(self): + map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}] + + origin = pa.table( + {"id": [0, 1, 2, 3, 4], "m": map_data}, + schema=pa.schema( + [pa.field("id", pa.int64()), pa.field("m", pa.map_(pa.string(), pa.int64()), True)] + ), + ) + for schema in [ + "id long, m map", + StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())), + ]: + df = self.spark.createDataFrame(origin, schema=schema) + pdf = df.toArrow().to_pandas() + assert_frame_equal(origin.to_pandas(), pdf) + + def test_createDataFrame_pandas_with_int_col_names(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): - self.check_createDataFrame_with_int_col_names(arrow_enabled) + self.check_createDataFrame_pandas_with_int_col_names(arrow_enabled) - def check_createDataFrame_with_int_col_names(self, arrow_enabled): + def check_createDataFrame_pandas_with_int_col_names(self, arrow_enabled): import numpy as np pdf = pd.DataFrame(np.random.rand(4, 2)) @@ -761,6 +1056,13 @@ def check_createDataFrame_with_int_col_names(self, arrow_enabled): pdf_col_names = [str(c) for c in pdf.columns] self.assertEqual(pdf_col_names, df.columns) + def test_createDataFrame_arrow_with_int_col_names(self): + import numpy as np + + t = pa.table(pd.DataFrame(np.random.rand(4, 2))) + df = self.spark.createDataFrame(t) + self.assertEqual(t.schema.names, df.columns) + # Regression test for SPARK-23314 def test_timestamp_dst(self): # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am @@ -845,6 +1147,23 @@ def test_createDataFrame_with_category_type(self): self.assertIsInstance(arrow_first_category_element, str) self.assertIsInstance(spark_first_category_element, str) + def 
test_createDataFrame_with_dictionary_type_nulls(self): + import pyarrow.compute as pc + + t = pa.table({"A": ["a", "b", "c", None, "a"]}) + t = t.add_column(1, "B", pc.dictionary_encode(t["A"])) + category_first_element = sorted(t["B"].combine_chunks().dictionary.to_pylist())[0] + + df = self.spark.createDataFrame(t) + type = df.dtypes[1][1] + result = df.toArrow() + result_first_category_element = result["B"][0].as_py() + + # ensure original category elements are string + self.assertIsInstance(category_first_element, str) + self.assertEqual(type, "string") + self.assertIsInstance(result_first_category_element, str) + def test_createDataFrame_with_float_index(self): # SPARK-32098: float index should not produce duplicated or truncated Spark DataFrame self.assertEqual( @@ -883,6 +1202,10 @@ def check_toPandas_error(self, arrow_enabled): with self.assertRaises(ArithmeticException): self.spark.sql("select 1/0").toPandas() + def test_toArrow_error(self): + with self.assertRaises(ArithmeticException): + self.spark.sql("select 1/0").toArrow() + def test_toPandas_duplicate_field_names(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): @@ -937,12 +1260,45 @@ def check_toPandas_duplicate_field_names(self, arrow_enabled): expected = pd.DataFrame.from_records(data, columns=schema.names) assert_frame_equal(df.toPandas(), expected) - def test_createDataFrame_duplicate_field_names(self): + def test_toArrow_duplicate_field_names(self): + data = [[1, 1], [2, 2]] + names = ["a", "a"] + df = self.spark.createDataFrame(data, names) + + expected = pa.table( + [[1, 2], [1, 2]], + schema=pa.schema([pa.field("a", pa.int64()), pa.field("a", pa.int64())]), + ) + + self.assertTrue(df.toArrow().equals(expected)) + + data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6), Row(7, 8, "y", 9, "z"))] + schema = ( + StructType() + .add("struct", StructType().add("x", StringType()).add("x", IntegerType())) + .add( + "struct", + StructType() + .add("a", IntegerType()) + .add("x", IntegerType()) + .add("x", StringType()) + .add("y", IntegerType()) + .add("y", StringType()), + ) + ) + df = self.spark.createDataFrame(data, schema=schema) + + with self.assertRaisesRegex( + UnsupportedOperationException, "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" + ): + df.toArrow() + + def test_createDataFrame_pandas_duplicate_field_names(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): - self.check_createDataFrame_duplicate_field_names(arrow_enabled) + self.check_createDataFrame_pandas_duplicate_field_names(arrow_enabled) - def check_createDataFrame_duplicate_field_names(self, arrow_enabled): + def check_createDataFrame_pandas_duplicate_field_names(self, arrow_enabled): schema = ( StructType() .add("struct", StructType().add("x", StringType()).add("x", IntegerType())) @@ -965,6 +1321,66 @@ def check_createDataFrame_duplicate_field_names(self, arrow_enabled): self.assertEqual(df.collect(), data) + def test_createDataFrame_arrow_duplicate_field_names(self): + t = pa.table( + [[1, 2], [1, 2]], + schema=pa.schema([pa.field("a", pa.int64()), pa.field("a", pa.int64())]), + ) + schema = StructType().add("a", LongType()).add("a", LongType()) + + df = self.spark.createDataFrame(t) + + self.assertTrue(df.toArrow().equals(t)) + + df = self.spark.createDataFrame(t, schema=schema) + + self.assertTrue(df.toArrow().equals(t)) + + t = pa.table( + [ + pa.StructArray.from_arrays( + [ + pa.array(["a", "x"], type=pa.string()), + pa.array([1, 6], type=pa.int32()), + ], + 
names=["x", "x"], + ), + pa.StructArray.from_arrays( + [ + pa.array([2, 7], type=pa.int32()), + pa.array([3, 8], type=pa.int32()), + pa.array(["b", "y"], type=pa.string()), + pa.array([4, 9], type=pa.int32()), + pa.array(["c", "z"], type=pa.string()), + ], + names=["a", "x", "x", "y", "y"], + ), + ], + names=["struct", "struct"], + ) + schema = ( + StructType() + .add("struct", StructType().add("x", StringType()).add("x", IntegerType())) + .add( + "struct", + StructType() + .add("a", IntegerType()) + .add("x", IntegerType()) + .add("x", StringType()) + .add("y", IntegerType()) + .add("y", StringType()), + ) + ) + with self.assertRaisesRegex( + UnsupportedOperationException, "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" + ): + self.spark.createDataFrame(t) + + with self.assertRaisesRegex( + UnsupportedOperationException, "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT" + ): + self.spark.createDataFrame(t, schema) + def test_toPandas_empty_columns(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): @@ -976,38 +1392,39 @@ def check_toPandas_empty_columns(self, arrow_enabled): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}): assert_frame_equal(df.toPandas(), pd.DataFrame(columns=[], index=range(2))) - def test_createDataFrame_nested_timestamp(self): + def test_toArrow_empty_columns(self): + df = self.spark.range(2).select([]) + + self.assertTrue(df.toArrow().equals(pa.table([]))) + + def test_toPandas_empty_rows(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): - self.check_createDataFrame_nested_timestamp(arrow_enabled) + self.check_toPandas_empty_rows(arrow_enabled) - def check_createDataFrame_nested_timestamp(self, arrow_enabled): - schema = ( - StructType() - .add("ts", TimestampType()) - .add("ts_ntz", TimestampNTZType()) - .add( - "struct", StructType().add("ts", TimestampType()).add("ts_ntz", TimestampNTZType()) + def check_toPandas_empty_rows(self, arrow_enabled): + df = self.spark.range(2).limit(0) + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}): + assert_frame_equal(df.toPandas(), pd.DataFrame({"id": pd.Series([], dtype="int64")})) + + def test_toArrow_empty_rows(self): + df = self.spark.range(2).limit(0) + + self.assertTrue( + df.toArrow().equals( + pa.Table.from_arrays([[]], schema=pa.schema([pa.field("id", pa.int64(), False)])) ) - .add("array", ArrayType(TimestampType())) - .add("array_ntz", ArrayType(TimestampNTZType())) - .add("map", MapType(StringType(), TimestampType())) - .add("map_ntz", MapType(StringType(), TimestampNTZType())) ) - data = [ - Row( - datetime.datetime(2023, 1, 1, 0, 0, 0), - datetime.datetime(2023, 1, 1, 0, 0, 0), - Row( - datetime.datetime(2023, 1, 1, 0, 0, 0), - datetime.datetime(2023, 1, 1, 0, 0, 0), - ), - [datetime.datetime(2023, 1, 1, 0, 0, 0)], - [datetime.datetime(2023, 1, 1, 0, 0, 0)], - dict(ts=datetime.datetime(2023, 1, 1, 0, 0, 0)), - dict(ts_ntz=datetime.datetime(2023, 1, 1, 0, 0, 0)), - ) - ] + + def test_createDataFrame_pandas_nested_timestamp(self): + for arrow_enabled in [True, False]: + with self.subTest(arrow_enabled=arrow_enabled): + self.check_createDataFrame_pandas_nested_timestamp(arrow_enabled) + + def check_createDataFrame_pandas_nested_timestamp(self, arrow_enabled): + schema = self.schema_nested_timestamp + data = self.data_nested_timestamp pdf = pd.DataFrame.from_records(data, columns=schema.names) with self.sql_conf( @@ -1018,18 +1435,23 @@ def check_createDataFrame_nested_timestamp(self, 
arrow_enabled): ): df = self.spark.createDataFrame(pdf, schema) - expected = Row( - ts=datetime.datetime(2022, 12, 31, 21, 0, 0), - ts_ntz=datetime.datetime(2023, 1, 1, 0, 0, 0), - struct=Row( - ts=datetime.datetime(2022, 12, 31, 21, 0, 0), - ts_ntz=datetime.datetime(2023, 1, 1, 0, 0, 0), - ), - array=[datetime.datetime(2022, 12, 31, 21, 0, 0)], - array_ntz=[datetime.datetime(2023, 1, 1, 0, 0, 0)], - map=dict(ts=datetime.datetime(2022, 12, 31, 21, 0, 0)), - map_ntz=dict(ts_ntz=datetime.datetime(2023, 1, 1, 0, 0, 0)), - ) + expected = self.data_nested_timestamp_expected_ny + + self.assertEqual(df.first(), expected) + + def test_createDataFrame_arrow_nested_timestamp(self): + from pyspark.sql.pandas.types import to_arrow_schema + + schema = self.schema_nested_timestamp + data = self.data_nested_timestamp + pdf = pd.DataFrame.from_records(data, columns=schema.names) + arrow_schema = to_arrow_schema(schema, timestamp_utc=False) + t = pa.Table.from_pandas(pdf, arrow_schema) + + with self.sql_conf({"spark.sql.session.timeZone": "America/New_York"}): + df = self.spark.createDataFrame(t, schema) + + expected = self.data_nested_timestamp_expected_ny self.assertEqual(df.first(), expected) @@ -1066,32 +1488,8 @@ def test_toPandas_nested_timestamp(self): self.check_toPandas_nested_timestamp(arrow_enabled) def check_toPandas_nested_timestamp(self, arrow_enabled): - schema = ( - StructType() - .add("ts", TimestampType()) - .add("ts_ntz", TimestampNTZType()) - .add( - "struct", StructType().add("ts", TimestampType()).add("ts_ntz", TimestampNTZType()) - ) - .add("array", ArrayType(TimestampType())) - .add("array_ntz", ArrayType(TimestampNTZType())) - .add("map", MapType(StringType(), TimestampType())) - .add("map_ntz", MapType(StringType(), TimestampNTZType())) - ) - data = [ - Row( - datetime.datetime(2023, 1, 1, 0, 0, 0), - datetime.datetime(2023, 1, 1, 0, 0, 0), - Row( - datetime.datetime(2023, 1, 1, 0, 0, 0), - datetime.datetime(2023, 1, 1, 0, 0, 0), - ), - [datetime.datetime(2023, 1, 1, 0, 0, 0)], - [datetime.datetime(2023, 1, 1, 0, 0, 0)], - dict(ts=datetime.datetime(2023, 1, 1, 0, 0, 0)), - dict(ts_ntz=datetime.datetime(2023, 1, 1, 0, 0, 0)), - ) - ] + schema = self.schema_nested_timestamp + data = self.data_nested_timestamp df = self.spark.createDataFrame(data, schema) with self.sql_conf( @@ -1122,6 +1520,46 @@ def check_toPandas_nested_timestamp(self, arrow_enabled): assert_frame_equal(pdf, expected) + def test_toArrow_nested_timestamp(self): + schema = self.schema_nested_timestamp + data = self.data_nested_timestamp + df = self.spark.createDataFrame(data, schema) + + t = df.toArrow() + + from pyspark.sql.pandas.types import to_arrow_schema + + arrow_schema = to_arrow_schema(schema) + expected = pa.Table.from_pydict( + { + "ts": [datetime.datetime(2023, 1, 1, 8, 0, 0)], + "ts_ntz": [datetime.datetime(2023, 1, 1, 0, 0, 0)], + "struct": [ + Row( + datetime.datetime(2023, 1, 1, 8, 0, 0), + datetime.datetime(2023, 1, 1, 0, 0, 0), + ) + ], + "array": [[datetime.datetime(2023, 1, 1, 8, 0, 0)]], + "array_ntz": [[datetime.datetime(2023, 1, 1, 0, 0, 0)]], + "map": [dict(ts=datetime.datetime(2023, 1, 1, 8, 0, 0))], + "map_ntz": [dict(ts_ntz=datetime.datetime(2023, 1, 1, 0, 0, 0))], + }, + schema=arrow_schema, + ) + + self.assertTrue(t.equals(expected)) + + @unittest.skip("SPARK-48302: Nulls are replaced with empty lists") + def test_arrow_map_timestamp_nulls_round_trip(self): + origin = pa.table( + [[dict(ts=datetime.datetime(2023, 1, 1, 8, 0, 0)), None]], + schema=pa.schema([("map", pa.map_(pa.string(), 
pa.timestamp("us", tz="UTC")))]), + ) + df = self.spark.createDataFrame(origin) + t = df.toArrow() + self.assertTrue(origin.equals(t)) + def test_createDataFrame_udt(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): diff --git a/python/pyspark/sql/tests/typing/test_session.yml b/python/pyspark/sql/tests/typing/test_session.yml index 8f48edb7e579e..d6eee82a7678e 100644 --- a/python/pyspark/sql/tests/typing/test_session.yml +++ b/python/pyspark/sql/tests/typing/test_session.yml @@ -51,25 +51,6 @@ spark.createDataFrame(["foo", "bar"], "string") -- case: createDataFrameScalarsInvalid - main: | - from pyspark.sql import SparkSession - from pyspark.sql.types import StructType, StructField, StringType, IntegerType - - spark = SparkSession.builder.getOrCreate() - - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True) - ]) - - # Invalid - scalars require schema - spark.createDataFrame(["foo", "bar"]) # E: Value of type variable "RowLike" of "createDataFrame" of "SparkSession" cannot be "str" [type-var] - - # Invalid - data has to match schema (either product -> struct or scalar -> atomic) - spark.createDataFrame([1, 2, 3], schema) # E: Value of type variable "RowLike" of "createDataFrame" of "SparkSession" cannot be "int" [type-var] - - - case: createDataFrameStructsInvalid main: | from pyspark.sql import SparkSession @@ -102,7 +83,9 @@ main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: RDD[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: Iterable[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame main:18: note: def createDataFrame(self, data: DataFrame, samplingRatio: Optional[float] = ...) -> DataFrame + main:18: note: def createDataFrame(self, data: Any, samplingRatio: Optional[float] = ...) -> DataFrame main:18: note: def createDataFrame(self, data: DataFrame, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame + main:18: note: def createDataFrame(self, data: Any, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame - case: createDataFrameFromEmptyRdd main: |