# Add support for StructType return in Scalar Pandas UDF (#582)
## Upstream SPARK-XXXXX ticket and PR link (if not applicable, explain)

- [[SPARK-26887][SQL][PYTHON][NS] Create datetime.date directly instead of creating datetime64 as intermediate data.](02f03b7) (merged Feb 18)

- [[SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF](e8193ed) (merged Mar 7) <= Main Commit

- [[SPARK-27163][PYTHON] Cleanup and consolidate Pandas UDF functionality](8d69b8c) (merged Mar 21)

- [[SPARK-27240][PYTHON] Use pandas DataFrame for struct type argument in Scalar Pandas UDF.](adb3a01) (merged Mar 25)

## What changes were proposed in this pull request?

This is a series of four cherry-picks, all related to pandas_udfs. Each cherry-pick is an isolated code change and was applied without conflicts or manual edits.
Together they allow pandas_udfs to accept and return StructType values. This is required by the FoundryML design, in which all Stages can act on pandas or pyspark dataframes interchangeably.
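
For context, here is a minimal usage sketch of the capability this series enables, modeled on the `split_expand` doctest added to `python/pyspark/sql/functions.py` below. It assumes an active `SparkSession` bound to `spark`; the UDF name, column names, and data are illustrative.

```python
from pyspark.sql.functions import pandas_udf

# A scalar pandas_udf whose return type is a StructType: the Python function
# returns a pandas.DataFrame whose columns map onto the struct fields
# (by name when the columns are labeled with strings, otherwise by position).
@pandas_udf("first string, last string")
def split_name(s):
    parts = s.str.split(" ", n=1, expand=True)
    return parts.rename(columns={0: "first", 1: "last"})

df = spark.createDataFrame([("John Doe",), ("Jane Roe",)], ["name"])
df.select(split_name("name").alias("name_parts")).show(truncate=False)
```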
simon-slowik authored and bulldozer-bot[bot] committed Jul 1, 2019
1 parent eca4aa8 commit 7ca6659
Showing 12 changed files with 387 additions and 142 deletions.
223 changes: 146 additions & 77 deletions python/pyspark/serializers.py
@@ -66,6 +66,7 @@
else:
import pickle
protocol = 3
basestring = unicode = str
xrange = range

from pyspark import cloudpickle
@@ -245,112 +246,180 @@ def __repr__(self):
return "ArrowStreamSerializer"


def _create_batch(series, timezone, safecheck):
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
Serializes Pandas.Series as Arrow data with Arrow streaming format.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:param timezone: A timezone to respect when handling timestamp values
:return: Arrow RecordBatch
"""
import decimal
from distutils.version import LooseVersion
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
elif t is not None and pa.types.is_string(t) and sys.version < '3':
# TODO: need decode before converting to Arrow in Python 2
# TODO: don't need as of Arrow 0.9.1
return pa.Array.from_pandas(s.apply(
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
elif t is not None and pa.types.is_decimal(t) and \
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
return pa.Array.from_pandas(s.apply(
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
return pa.Array.from_pandas(s, mask=mask, type=t)

try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
"conversions warned by Arrow. Arrow safe type check can be " + \
"disabled by using SQL config " + \
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
raise RuntimeError(error_msg % (s.dtype, t), e)
return array

arrs = [create_array(s, t) for s, t in series]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])


class ArrowStreamPandasSerializer(Serializer):
"""
Serializes Pandas.Series as Arrow data with Arrow streaming format.
:param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
:param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
"""

def __init__(self, timezone, safecheck):
def __init__(self, timezone, safecheck, assign_cols_by_name):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import from_arrow_type, \
_check_series_convert_date, _check_series_localize_timestamps
def arrow_to_pandas(self, arrow_column, data_type):
from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps

s = arrow_column.to_pandas()
s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
s = _arrow_column_to_pandas(arrow_column, data_type)
s = _check_series_localize_timestamps(s, self._timezone)
return s

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series,
with optional type.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
import decimal
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
elif t is not None and pa.types.is_string(t) and sys.version < '3':
# TODO: need decode before converting to Arrow in Python 2
# TODO: don't need as of Arrow 0.9.1
return pa.Array.from_pandas(s.apply(
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
elif t is not None and pa.types.is_decimal(t) and \
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
return pa.Array.from_pandas(s.apply(
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
return pa.Array.from_pandas(s, mask=mask, type=t)

try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
"conversions warned by Arrow. Arrow safe type check can be " + \
"disabled by using SQL config " + \
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
raise RuntimeError(error_msg % (s.dtype, t), e)
return array

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s)))

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, basestring)
for name in s.columns):
arrs_names = [(create_array(s[field.name], field.type), field.name)
for field in t]
# Assign result columns by position
else:
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)]

struct_arrs, struct_names = zip(*arrs_names)

# TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
else:
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])

def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
import pyarrow as pa
writer = None
try:
for series in iterator:
batch = _create_batch(series, self._timezone, self._safecheck)
if writer is None:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
finally:
if writer is not None:
writer.close()
batches = (self._create_batch(series) for series in iterator)
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)

def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
import pyarrow as pa
reader = pa.ipc.open_stream(stream)

for batch in reader:
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
from pyspark.sql.types import from_arrow_type
for batch in batches:
yield [self.arrow_to_pandas(c, from_arrow_type(c.type))
for c in pa.Table.from_batches([batch]).itercolumns()]

def __repr__(self):
return "ArrowStreamPandasSerializer"


class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""
Serializer used by Python worker to evaluate Pandas UDFs
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
super(ArrowStreamPandasUDFSerializer, self) \
.__init__(timezone, safecheck, assign_cols_by_name)
self._df_for_struct = df_for_struct

def arrow_to_pandas(self, arrow_column, data_type):
from pyspark.sql.types import StructType, \
_arrow_column_to_pandas, _check_dataframe_localize_timestamps

if self._df_for_struct and type(data_type) == StructType:
import pandas as pd
series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name)
for column, field in zip(arrow_column.flatten(), data_type)]
s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type)
return s

def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""

def init_stream_yield_batches():
should_write_start_length = True
for series in iterator:
batch = self._create_batch(series)
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch

return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

def __repr__(self):
return "ArrowStreamPandasUDFSerializer"


class BatchedSerializer(Serializer):

"""
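
To make the struct handling added to `_create_batch` above concrete, here is a small standalone sketch of the same conversion: a pandas.DataFrame produced by a UDF is turned into a pyarrow `StructArray` field by field. It assumes pyarrow >= 0.9.0 (for older versions the `from_arrays` arguments are swapped, as the TODO in the diff notes); the field names and values are illustrative.

```python
import pandas as pd
import pyarrow as pa

# Arrow struct type for the UDF's return schema, and the pandas.DataFrame the UDF produced.
struct_type = pa.struct([("first", pa.string()), ("last", pa.string())])
pdf = pd.DataFrame({"first": ["John", "Jane"], "last": ["Doe", None]})

# One Arrow array per struct field, looked up by column name (the by-name path above).
field_arrays = [
    pa.Array.from_pandas(pdf[field.name], mask=pdf[field.name].isnull(), type=field.type)
    for field in struct_type
]
struct_array = pa.StructArray.from_arrays(field_arrays,
                                          [field.name for field in struct_type])

# The struct array becomes one column ("_0") of the record batch sent to the JVM.
batch = pa.RecordBatch.from_arrays([struct_array], ["_0"])
print(batch.schema)
```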
5 changes: 2 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -2107,14 +2107,13 @@ def toPandas(self):
# of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.types import _check_dataframe_convert_date, \
from pyspark.sql.types import _arrow_table_to_pandas, \
_check_dataframe_localize_timestamps
import pyarrow
batches = self._collectAsArrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
pdf = table.to_pandas()
pdf = _check_dataframe_convert_date(pdf, self.schema)
pdf = _arrow_table_to_pandas(table, self.schema)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
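
A hedged usage sketch of the `toPandas` path changed above (the SPARK-26887 cherry-pick): with Arrow enabled, date columns are materialized as `datetime.date` objects directly instead of going through an intermediate datetime64 conversion. It assumes an active `SparkSession` bound to `spark`.

```python
import datetime

spark.conf.set("spark.sql.execution.arrow.enabled", "true")

sdf = spark.createDataFrame([(1, datetime.date(2019, 3, 7))], ["id", "d"])
pdf = sdf.toPandas()          # collects Arrow batches, then _arrow_table_to_pandas
print(type(pdf["d"][0]))      # expected: <class 'datetime.date'>
```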
12 changes: 11 additions & 1 deletion python/pyspark/sql/functions.py
@@ -2872,8 +2872,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
The length of the returned `pandas.Series` must be the same as that of the input `pandas.Series`.
If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
:class:`MapType`, :class:`StructType` are currently not supported as output types.
:class:`MapType`, nested :class:`StructType` are currently not supported as output types.
Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
@@ -2898,6 +2899,15 @@ def pandas_udf(f=None, returnType=None, functionType=None):
+----------+--------------+------------+
| 8| JOHN DOE| 22|
+----------+--------------+------------+
>>> @pandas_udf("first string, last string") # doctest: +SKIP
... def split_expand(n):
... return n.str.split(expand=True)
>>> df.select(split_expand("name")).show() # doctest: +SKIP
+------------------+
|split_expand(name)|
+------------------+
| [John, Doe]|
+------------------+
.. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
column, but is the length of an internal batch used for each call to the function.
41 changes: 24 additions & 17 deletions python/pyspark/sql/session.py
@@ -530,15 +530,29 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowStreamSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from distutils.version import LooseVersion
from pyspark.serializers import ArrowStreamPandasSerializer
from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
from pyspark.sql.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version

require_minimum_pandas_version()
require_minimum_pyarrow_version()

from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
import pyarrow as pa

# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
if LooseVersion(pa.__version__) < LooseVersion("0.12.0"):
temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False)
arrow_schema = temp_batch.schema
else:
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
schema = struct

# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
@@ -555,31 +569,24 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

# Create Arrow record batches
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
timezone, safecheck)
for pdf_slice in pdf_slices]

# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
if isinstance(schema, (list, tuple)):
struct = from_arrow_schema(batches[0].schema)
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct
# Create list of Arrow (columns, type) for serializer dump_stream
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
for pdf_slice in pdf_slices]

jsqlContext = self._wrapped._jsqlContext

safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)

def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)

def create_RDD_server():
return self._jvm.ArrowRDDServer(jsqlContext)

# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func,
create_RDD_server)
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
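
Finally, a hedged sketch of the `createDataFrame` path reworked above: with Arrow enabled, the pandas.DataFrame is sliced into (column, arrow_type) pairs and shipped to the JVM through `ArrowStreamPandasSerializer`; passing the schema as a list of names now derives the Spark schema from the Arrow schema inferred from pandas. Again assumes an active `SparkSession` bound to `spark`; the data is illustrative.

```python
import pandas as pd

spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["John", "Jane", "Joe"]})

# Schema given as a list of names: the new code above infers the Arrow schema
# from the pandas.DataFrame and builds the Spark StructType from it.
sdf = spark.createDataFrame(pdf, schema=["id", "name"])
sdf.show()
```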