diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 3db259551fa8b..a2c59fedfc8cd 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -311,10 +311,9 @@ def __init__(self, timezone, safecheck): def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ - _check_series_convert_date, _check_series_localize_timestamps + _arrow_column_to_pandas, _check_series_localize_timestamps - s = arrow_column.to_pandas() - s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type)) s = _check_series_localize_timestamps(s, self._timezone) return s diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a1056d0b787e3..472d2969b3e19 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2107,14 +2107,13 @@ def toPandas(self): # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. if use_arrow: try: - from pyspark.sql.types import _check_dataframe_convert_date, \ + from pyspark.sql.types import _arrow_table_to_pandas, \ _check_dataframe_localize_timestamps import pyarrow batches = self._collectAsArrow() if len(batches) > 0: table = pyarrow.Table.from_batches(batches) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) + pdf = _arrow_table_to_pandas(table, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 8a62500b17f27..38a6402c01322 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -68,7 +68,9 @@ def setUpClass(cls): (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)), + (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"), + date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))] # TODO: remove version check once minimum pyarrow version is 0.10.0 if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): @@ -76,6 +78,7 @@ def setUpClass(cls): cls.data[0] = cls.data[0] + (bytearray(b"a"),) cls.data[1] = cls.data[1] + (bytearray(b"bb"),) cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + cls.data[3] = cls.data[3] + (bytearray(b"dddd"),) @classmethod def tearDownClass(cls): diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 6a6865a9fb16d..28ef98d7b3f1e 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -349,7 +349,8 @@ def test_vectorized_udf_dates(self): data = [(0, date(1969, 1, 1),), (1, date(2012, 2, 2),), (2, None,), - (3, date(2100, 4, 4),)] + (3, date(2100, 4, 4),), + (4, date(2262, 4, 12),)] df = self.spark.createDataFrame(data, schema=schema) date_copy = pandas_udf(lambda t: t, returnType=DateType()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 4b8f2efff4acc..348cb5b118594 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1681,38 +1681,52 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _check_series_convert_date(series, data_type): - """ - Cast the series to datetime.date if it's a date type, otherwise returns the original series. +def _arrow_column_to_pandas(column, data_type): + """ Convert Arrow Column to pandas Series. - :param series: pandas.Series - :param data_type: a Spark data type for the series + :param series: pyarrow.lib.Column + :param data_type: a Spark data type for the column """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType: - return series.dt.date + # If the given column is a date type column, creates a series of datetime.date directly instead + # of creating datetime64[ns] as intermediate data to avoid overflow caused by datetime64[ns] + # type handling. + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if type(data_type) == DateType: + return pd.Series(column.to_pylist(), name=column.name) + else: + return column.to_pandas() else: - return series + # Since Arrow 0.11.0, support date_as_object to return datetime.date instead of + # np.datetime64. + return column.to_pandas(date_as_object=True) -def _check_dataframe_convert_date(pdf, schema): - """ Correct date type value to use datetime.date. +def _arrow_table_to_pandas(table, schema): + """ Convert Arrow Table to pandas DataFrame. Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should use datetime.date to match the behavior with when Arrow optimization is disabled. - :param pdf: pandas.DataFrame - :param schema: a Spark schema of the pandas.DataFrame + :param table: pyarrow.lib.Table + :param schema: a Spark schema of the pyarrow.lib.Table """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"): - for field in schema: - pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) - return pdf + # If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11 + # or use `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as + # intermediate data. + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if any(type(field.dataType) == DateType for field in schema): + return pd.concat([_arrow_column_to_pandas(column, field.dataType) + for column, field in zip(table.itercolumns(), schema)], axis=1) + else: + return table.to_pandas() + else: + return table.to_pandas(date_as_object=True) def _get_local_timezone():