Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-26887][SQL][PYTHON] Create datetime.date directly instead of creating datetime64[ns] as intermediate data. #23795

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,9 @@ def __init__(self, timezone, safecheck):

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import from_arrow_type, \
_check_series_convert_date, _check_series_localize_timestamps
_arrow_column_to_pandas, _check_series_localize_timestamps

s = arrow_column.to_pandas()
s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type))
s = _check_series_localize_timestamps(s, self._timezone)
return s

Expand Down
5 changes: 2 additions & 3 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,14 +2107,13 @@ def toPandas(self):
# of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.types import _check_dataframe_convert_date, \
from pyspark.sql.types import _arrow_table_to_pandas, \
_check_dataframe_localize_timestamps
import pyarrow
batches = self._collectAsArrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
pdf = table.to_pandas()
pdf = _check_dataframe_convert_date(pdf, self.schema)
pdf = _arrow_table_to_pandas(table, self.schema)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,17 @@ def setUpClass(cls):
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)),
(u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))]

# TODO: remove version check once minimum pyarrow version is 0.10.0
if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
cls.schema.add(StructField("9_binary_t", BinaryType(), True))
cls.data[0] = cls.data[0] + (bytearray(b"a"),)
cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
cls.data[3] = cls.data[3] + (bytearray(b"dddd"),)

@classmethod
def tearDownClass(cls):
Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/sql/tests/test_pandas_udf_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ def test_vectorized_udf_dates(self):
data = [(0, date(1969, 1, 1),),
(1, date(2012, 2, 2),),
(2, None,),
(3, date(2100, 4, 4),)]
(3, date(2100, 4, 4),),
(4, date(2262, 4, 12),)]
df = self.spark.createDataFrame(data, schema=schema)

date_copy = pandas_udf(lambda t: t, returnType=DateType())
Expand Down
54 changes: 34 additions & 20 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,38 +1681,52 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])


def _check_series_convert_date(series, data_type):
"""
Cast the series to datetime.date if it's a date type, otherwise returns the original series.
def _arrow_column_to_pandas(column, data_type):
""" Convert Arrow Column to pandas Series.

:param series: pandas.Series
:param data_type: a Spark data type for the series
:param series: pyarrow.lib.Column
:param data_type: a Spark data type for the column
"""
import pyarrow
import pandas as pd
import pyarrow as pa
from distutils.version import LooseVersion
# As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910
if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType:
return series.dt.date
# If the given column is a date type column, creates a series of datetime.date directly instead
# of creating datetime64[ns] as intermediate data to avoid overflow caused by datetime64[ns]
# type handling.
if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
if type(data_type) == DateType:
return pd.Series(column.to_pylist(), name=column.name)
else:
return column.to_pandas()
else:
return series
# Since Arrow 0.11.0, support date_as_object to return datetime.date instead of
# np.datetime64.
return column.to_pandas(date_as_object=True)


def _check_dataframe_convert_date(pdf, schema):
""" Correct date type value to use datetime.date.
def _arrow_table_to_pandas(table, schema):
""" Convert Arrow Table to pandas DataFrame.

Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
use datetime.date to match the behavior with when Arrow optimization is disabled.

:param pdf: pandas.DataFrame
:param schema: a Spark schema of the pandas.DataFrame
:param table: pyarrow.lib.Table
:param schema: a Spark schema of the pyarrow.lib.Table
"""
import pyarrow
import pandas as pd
import pyarrow as pa
from distutils.version import LooseVersion
# As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910
if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"):
for field in schema:
pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
return pdf
# If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11
# or use `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as
# intermediate data.
if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good @ueshin.

@ueshin, @BryanCutler , BTW, which version of PyArrow do you think we should bump up to in Spark 3.0.0? I was thinking about matching it to 0.12.0, or 0.11.0. I think it's overhead that we should test all the pyarrow versions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to bump to 0.12.0 because I think that would allow us to clean up the code the most, but since it's a raised error if the user doesn't have that version, it might too restrictive. Let's definitely make a JIRA to discuss more.

if any(type(field.dataType) == DateType for field in schema):
return pd.concat([_arrow_column_to_pandas(column, field.dataType)
for column, field in zip(table.itercolumns(), schema)], axis=1)
else:
return table.to_pandas()
else:
return table.to_pandas(date_as_object=True)


def _get_local_timezone():
Expand Down