diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 1d170530d2850..dfe1aa14798f4 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -66,6 +66,7 @@ else: import pickle protocol = 3 + basestring = unicode = str xrange = range from pyspark import cloudpickle @@ -245,112 +246,180 @@ def __repr__(self): return "ArrowStreamSerializer" -def _create_batch(series, timezone, safecheck): +class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ - Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. + Serializes Pandas.Series as Arrow data with Arrow streaming format. - :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) :param timezone: A timezone to respect when handling timestamp values - :return: Arrow RecordBatch - """ - import decimal - from distutils.version import LooseVersion - import pyarrow as pa - from pyspark.sql.types import _check_series_convert_timestamps_internal - # Make input conform to [(series1, type1), (series2, type2), ...] - if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - def create_array(s, t): - mask = s.isnull() - # Ensure timestamp series are in expected form for Spark internal representation - # TODO: maybe don't need None check anymore as of Arrow 0.9.1 - if t is not None and pa.types.is_timestamp(t): - s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) - # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 - return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) - elif t is not None and pa.types.is_string(t) and sys.version < '3': - # TODO: need decode before converting to Arrow in Python 2 - # TODO: don't need as of Arrow 0.9.1 - return pa.Array.from_pandas(s.apply( - lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) - elif t is not None and pa.types.is_decimal(t) and \ - LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. - return pa.Array.from_pandas(s.apply( - lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) - elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): - # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. - return pa.Array.from_pandas(s, mask=mask, type=t) - - try: - array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck) - except pa.ArrowException as e: - error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ - "Array (%s). It can be caused by overflows or other unsafe " + \ - "conversions warned by Arrow. Arrow safe type check can be " + \ - "disabled by using SQL config " + \ - "`spark.sql.execution.pandas.arrowSafeTypeConversion`." - raise RuntimeError(error_msg % (s.dtype, t), e) - return array - - arrs = [create_array(s, t) for s, t in series] - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - - -class ArrowStreamPandasSerializer(Serializer): - """ - Serializes Pandas.Series as Arrow data with Arrow streaming format. 
+ :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation + :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name """ - def __init__(self, timezone, safecheck): + def __init__(self, timezone, safecheck, assign_cols_by_name): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone self._safecheck = safecheck + self._assign_cols_by_name = assign_cols_by_name - def arrow_to_pandas(self, arrow_column): - from pyspark.sql.types import from_arrow_type, \ - _check_series_convert_date, _check_series_localize_timestamps + def arrow_to_pandas(self, arrow_column, data_type): + from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps - s = arrow_column.to_pandas() - s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _arrow_column_to_pandas(arrow_column, data_type) s = _check_series_localize_timestamps(s, self._timezone) return s + def _create_batch(self, series): + """ + Create an Arrow record batch from the given pandas.Series or list of Series, + with optional type. + + :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :return: Arrow RecordBatch + """ + import decimal + from distutils.version import LooseVersion + import pandas as pd + import pyarrow as pa + from pyspark.sql.types import _check_series_convert_timestamps_internal + # Make input conform to [(series1, type1), (series2, type2), ...] + if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + def create_array(s, t): + mask = s.isnull() + # Ensure timestamp series are in expected form for Spark internal representation + # TODO: maybe don't need None check anymore as of Arrow 0.9.1 + if t is not None and pa.types.is_timestamp(t): + s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone) + # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 + return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + elif t is not None and pa.types.is_string(t) and sys.version < '3': + # TODO: need decode before converting to Arrow in Python 2 + # TODO: don't need as of Arrow 0.9.1 + return pa.Array.from_pandas(s.apply( + lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) + elif t is not None and pa.types.is_decimal(t) and \ + LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. + return pa.Array.from_pandas(s.apply( + lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) + elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. + return pa.Array.from_pandas(s, mask=mask, type=t) + + try: + array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) + except pa.ArrowException as e: + error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ + "Array (%s). It can be caused by overflows or other unsafe " + \ + "conversions warned by Arrow. Arrow safe type check can be " + \ + "disabled by using SQL config " + \ + "`spark.sql.execution.pandas.arrowSafeTypeConversion`." 
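# Illustration only (not part of the patch): a minimal sketch of the Arrow safe-cast
# behaviour that `spark.sql.execution.pandas.arrowSafeTypeConversion` toggles via the
# `safe=` argument used above. With safe=True a lossy conversion such as float -> int32
# raises an ArrowException; with safe=False the values are truncated instead.
import pandas as pd
import pyarrow as pa

s = pd.Series([1.5, 2.0])
try:
    pa.Array.from_pandas(s, type=pa.int32(), safe=True)   # raises: 1.5 would be truncated
except pa.ArrowException as e:
    print("unsafe conversion detected:", e)

truncated = pa.Array.from_pandas(s, type=pa.int32(), safe=False)  # -> [1, 2]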
+ raise RuntimeError(error_msg % (s.dtype, t), e) + return array + + arrs = [] + for s, t in series: + if t is not None and pa.types.is_struct(t): + if not isinstance(s, pd.DataFrame): + raise ValueError("A field of type StructType expects a pandas.DataFrame, " + "but got: %s" % str(type(s))) + + # Input partition and result pandas.DataFrame empty, make empty Arrays with struct + if len(s) == 0 and len(s.columns) == 0: + arrs_names = [(pa.array([], type=field.type), field.name) for field in t] + # Assign result columns by schema name if user labeled with strings + elif self._assign_cols_by_name and any(isinstance(name, basestring) + for name in s.columns): + arrs_names = [(create_array(s[field.name], field.type), field.name) + for field in t] + # Assign result columns by position + else: + arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) + for i, field in enumerate(t)] + + struct_arrs, struct_names = zip(*arrs_names) + + # TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version + if LooseVersion(pa.__version__) < LooseVersion("0.9.0"): + arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs)) + else: + arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) + else: + arrs.append(create_array(s, t)) + + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. """ - import pyarrow as pa - writer = None - try: - for series in iterator: - batch = _create_batch(series, self._timezone, self._safecheck) - if writer is None: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - writer = pa.RecordBatchStreamWriter(stream, batch.schema) - writer.write_batch(batch) - finally: - if writer is not None: - writer.close() + batches = (self._create_batch(series) for series in iterator) + super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
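# Illustration only (not part of the patch): how the struct branch above assembles an
# Arrow StructArray from a pandas.DataFrame returned by a UDF. The field names
# ('id', 'str') are example names, not taken from the patch.
import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({'id': [1, 2, 3], 'str': ['a', 'b', 'c']})
t = pa.struct([pa.field('id', pa.int64()), pa.field('str', pa.string())])

# One child array per struct field, selecting the pandas column by field name
# (the by-position variant would use pdf[pdf.columns[i]] instead).
children = [pa.Array.from_pandas(pdf[field.name], type=field.type) for field in t]
struct_arr = pa.StructArray.from_arrays(children, [field.name for field in t])
batch = pa.RecordBatch.from_arrays([struct_arr], ['_0'])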
""" + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) import pyarrow as pa - reader = pa.ipc.open_stream(stream) - - for batch in reader: - yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] + from pyspark.sql.types import from_arrow_type + for batch in batches: + yield [self.arrow_to_pandas(c, from_arrow_type(c.type)) + for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" +class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): + """ + Serializer used by Python worker to evaluate Pandas UDFs + """ + + def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False): + super(ArrowStreamPandasUDFSerializer, self) \ + .__init__(timezone, safecheck, assign_cols_by_name) + self._df_for_struct = df_for_struct + + def arrow_to_pandas(self, arrow_column, data_type): + from pyspark.sql.types import StructType, \ + _arrow_column_to_pandas, _check_dataframe_localize_timestamps + + if self._df_for_struct and type(data_type) == StructType: + import pandas as pd + series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name) + for column, field in zip(arrow_column.flatten(), data_type)] + s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone) + else: + s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type) + return s + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + + def init_stream_yield_batches(): + should_write_start_length = True + for series in iterator: + batch = self._create_batch(series) + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) + + def __repr__(self): + return "ArrowStreamPandasUDFSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a1056d0b787e3..472d2969b3e19 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2107,14 +2107,13 @@ def toPandas(self): # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. if use_arrow: try: - from pyspark.sql.types import _check_dataframe_convert_date, \ + from pyspark.sql.types import _arrow_table_to_pandas, \ _check_dataframe_localize_timestamps import pyarrow batches = self._collectAsArrow() if len(batches) > 0: table = pyarrow.Table.from_batches(batches) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) + pdf = _arrow_table_to_pandas(table, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cac566c74cd9b..584de7be33cac 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2872,8 +2872,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. 
+ If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`. - :class:`MapType`, :class:`StructType` are currently not supported as output types. + :class:`MapType`, nested :class:`StructType` are currently not supported as output types. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and :meth:`pyspark.sql.DataFrame.select`. @@ -2898,6 +2899,15 @@ def pandas_udf(f=None, returnType=None, functionType=None): +----------+--------------+------------+ | 8| JOHN DOE| 22| +----------+--------------+------------+ + >>> @pandas_udf("first string, last string") # doctest: +SKIP + ... def split_expand(n): + ... return n.str.split(expand=True) + >>> df.select(split_expand("name")).show() # doctest: +SKIP + +------------------+ + |split_expand(name)| + +------------------+ + | [John, Doe]| + +------------------+ .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input column, but is the length of an internal batch used for each call to the function. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index bdf1701a58959..b11e0f3ff69de 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -530,8 +530,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. """ - from pyspark.serializers import ArrowStreamSerializer, _create_batch - from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from distutils.version import LooseVersion + from pyspark.serializers import ArrowStreamPandasSerializer + from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,6 +540,19 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): require_minimum_pyarrow_version() from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + import pyarrow as pa + + # Create the Spark schema from list of names passed in with Arrow types + if isinstance(schema, (list, tuple)): + if LooseVersion(pa.__version__) < LooseVersion("0.12.0"): + temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False) + arrow_schema = temp_batch.schema + else: + arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False) + struct = StructType() + for name, field in zip(schema, arrow_schema): + struct.add(name, from_arrow_type(field.type), nullable=field.nullable) + schema = struct # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): @@ -555,22 +569,16 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) - # Create Arrow record batches - safecheck = self._wrapped._conf.arrowSafeTypeConversion() - batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)], - timezone, safecheck) - for pdf_slice in pdf_slices] - - # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) - if isinstance(schema, (list, tuple)): - struct = from_arrow_schema(batches[0].schema) - for i, name in enumerate(schema): - struct.fields[i].name = name - struct.names[i] = name - schema = struct + # Create list of Arrow (columns, 
type) for serializer dump_stream + arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)] + for pdf_slice in pdf_slices] jsqlContext = self._wrapped._jsqlContext + safecheck = self._wrapped._conf.arrowSafeTypeConversion() + col_by_name = True # col by name only applies to StructType columns, can't happen here + ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) + def reader_func(temp_filename): return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) @@ -578,8 +586,7 @@ def create_RDD_server(): return self._jvm.ArrowRDDServer(jsqlContext) # Create Spark DataFrame from Arrow stream file, using one batch per partition - jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func, - create_RDD_server) + jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server) jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext) df = DataFrame(jdf, self._wrapped) df._schema = schema diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 3ce6764278ce3..d82da5cec9836 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -68,7 +68,9 @@ def setUpClass(cls): (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)), + (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"), + date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))] # TODO: remove version check once minimum pyarrow version is 0.10.0 if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): @@ -76,6 +78,7 @@ def setUpClass(cls): cls.data[0] = cls.data[0] + (bytearray(b"a"),) cls.data[1] = cls.data[1] + (bytearray(b"bb"),) cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + cls.data[3] = cls.data[3] + (bytearray(b"dddd"),) @classmethod def tearDownClass(cls): diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index a0a25359d1e01..f7684d3fbcff0 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -273,6 +273,7 @@ def test_unsupported_types(self): StructField('map', MapType(StringType(), IntegerType())), StructField('arr_ts', ArrayType(TimestampType())), StructField('null', NullType()), + StructField('struct', StructType([StructField('l', LongType())])), ] # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0 diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index f29ff11ab998c..5efcfd343013a 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -23,13 +23,16 @@ import time import unittest +if sys.version >= '3': + unicode = str + from datetime import date, datetime from decimal import Decimal from distutils.version import LooseVersion from pyspark.rdd import PythonEvalType from pyspark.sql import Column -from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf +from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf from pyspark.sql.types import Row from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException @@ -265,6 +268,77 @@ def test_vectorized_udf_null_array(self): result = 
df.select(array_f(col('array'))) self.assertEquals(df.collect(), result.collect()) + def test_vectorized_udf_struct_type(self): + import pandas as pd + import pyarrow as pa + + df = self.spark.range(10) + return_type = StructType([ + StructField('id', LongType()), + StructField('str', StringType())]) + + def func(id): + return pd.DataFrame({'id': id, 'str': id.apply(unicode)}) + + f = pandas_udf(func, returnType=return_type) + + expected = df.select(struct(col('id'), col('id').cast('string').alias('str')) + .alias('struct')).collect() + + actual = df.select(f(col('id')).alias('struct')).collect() + self.assertEqual(expected, actual) + + g = pandas_udf(func, 'id: long, str: string') + actual = df.select(g(col('id')).alias('struct')).collect() + self.assertEqual(expected, actual) + + struct_f = pandas_udf(lambda x: x, return_type) + actual = df.select(struct_f(struct(col('id'), col('id').cast('string').alias('str')))) + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + from py4j.protocol import Py4JJavaError + with self.assertRaisesRegexp( + Py4JJavaError, + 'Unsupported type in conversion from Arrow'): + self.assertEqual(expected, actual.collect()) + else: + self.assertEqual(expected, actual.collect()) + + def test_vectorized_udf_struct_complex(self): + import pandas as pd + + df = self.spark.range(10) + return_type = StructType([ + StructField('ts', TimestampType()), + StructField('arr', ArrayType(LongType()))]) + + @pandas_udf(returnType=return_type) + def f(id): + return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)), + 'arr': id.apply(lambda i: [i, i + 1])}) + + actual = df.withColumn('f', f(col('id'))).collect() + for i, row in enumerate(actual): + id, f = row + self.assertEqual(i, id) + self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0]) + self.assertListEqual([i, i + 1], f[1]) + + def test_vectorized_udf_nested_struct(self): + nested_type = StructType([ + StructField('id', IntegerType()), + StructField('nested', StructType([ + StructField('foo', StringType()), + StructField('bar', FloatType()) + ])) + ]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Invalid returnType with scalar Pandas UDFs'): + pandas_udf(lambda x: x, returnType=nested_type) + def test_vectorized_udf_complex(self): df = self.spark.range(10).select( col('id').cast('int').alias('a'), @@ -302,6 +376,26 @@ def test_vectorized_udf_chained(self): res = df.select(g(f(col('id')))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_chained_struct_type(self): + import pandas as pd + + df = self.spark.range(10) + return_type = StructType([ + StructField('id', LongType()), + StructField('str', StringType())]) + + @pandas_udf(return_type) + def f(id): + return pd.DataFrame({'id': id, 'str': id.apply(unicode)}) + + g = pandas_udf(lambda x: x, return_type) + + expected = df.select(struct(col('id'), col('id').cast('string').alias('str')) + .alias('struct')).collect() + + actual = df.select(g(f(col('id'))).alias('struct')).collect() + self.assertEqual(expected, actual) + def test_vectorized_udf_wrong_return_type(self): with QuietTest(self.sc): with self.assertRaisesRegexp( @@ -331,6 +425,20 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_struct_with_empty_partition(self): + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\ + .withColumn('name', lit('John Doe')) + + @pandas_udf("first string, last string") + 
def split_expand(n): + return n.str.split(expand=True) + + result = df.select(split_expand('name')).collect() + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual('John', row[0]['first']) + self.assertEqual('Doe', row[0]['last']) + def test_vectorized_udf_varargs(self): df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) f = pandas_udf(lambda *v: v[0], LongType()) @@ -343,13 +451,18 @@ def test_vectorized_udf_unsupported_types(self): NotImplementedError, 'Invalid returnType.*scalar Pandas UDF.*MapType'): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'): + pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())]))) def test_vectorized_udf_dates(self): schema = StructType().add("idx", LongType()).add("date", DateType()) data = [(0, date(1969, 1, 1),), (1, date(2012, 2, 2),), (2, None,), - (3, date(2100, 4, 4),)] + (3, date(2100, 4, 4),), + (4, date(2262, 4, 12),)] df = self.spark.createDataFrame(data, schema=schema) date_copy = pandas_udf(lambda t: t, returnType=DateType()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f540954bcdb54..3246e9ee31f52 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1616,9 +1616,15 @@ def to_arrow_type(dt): # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') elif type(dt) == ArrayType: - if type(dt.elementType) == TimestampType: + if type(dt.elementType) in [StructType, TimestampType]: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) + elif type(dt) == StructType: + if any(type(field.dataType) == StructType for field in dt): + raise TypeError("Nested StructType not supported in conversion to Arrow") + fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + for field in dt] + arrow_type = pa.struct(fields) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type @@ -1671,6 +1677,16 @@ def from_arrow_type(at): if types.is_timestamp(at.value_type): raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) + elif types.is_struct(at): + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at) + + "\nPlease install pyarrow >= 0.10.0 for StructType support.") + if any(types.is_struct(field.type) for field in at): + raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at)) + return StructType( + [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) + for field in at]) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type @@ -1684,38 +1700,52 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _check_series_convert_date(series, data_type): - """ - Cast the series to datetime.date if it's a date type, otherwise returns the original series. +def _arrow_column_to_pandas(column, data_type): + """ Convert Arrow Column to pandas Series. 
- :param series: pandas.Series - :param data_type: a Spark data type for the series + :param series: pyarrow.lib.Column + :param data_type: a Spark data type for the column """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType: - return series.dt.date + # If the given column is a date type column, creates a series of datetime.date directly instead + # of creating datetime64[ns] as intermediate data to avoid overflow caused by datetime64[ns] + # type handling. + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if type(data_type) == DateType: + return pd.Series(column.to_pylist(), name=column.name) + else: + return column.to_pandas() else: - return series + # Since Arrow 0.11.0, support date_as_object to return datetime.date instead of + # np.datetime64. + return column.to_pandas(date_as_object=True) -def _check_dataframe_convert_date(pdf, schema): - """ Correct date type value to use datetime.date. +def _arrow_table_to_pandas(table, schema): + """ Convert Arrow Table to pandas DataFrame. Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should use datetime.date to match the behavior with when Arrow optimization is disabled. - :param pdf: pandas.DataFrame - :param schema: a Spark schema of the pandas.DataFrame + :param table: pyarrow.lib.Table + :param schema: a Spark schema of the pyarrow.lib.Table """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"): - for field in schema: - pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) - return pdf + # If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11 + # or use `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as + # intermediate data. 
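# Illustration only (not part of the patch): why dates are materialised as datetime.date
# objects rather than datetime64[ns]. pandas' datetime64[ns] cannot represent dates past
# 2262-04-11, so values such as the date(2262, 4, 12) added to the tests above would
# overflow; date_as_object (pyarrow >= 0.11) avoids the datetime64[ns] intermediate.
import datetime
import pyarrow as pa

table = pa.Table.from_arrays([pa.array([datetime.date(2262, 4, 12)])], names=['d'])
pdf = table.to_pandas(date_as_object=True)   # column 'd' holds datetime.date objects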
+ if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if any(type(field.dataType) == DateType for field in schema): + return pd.concat([_arrow_column_to_pandas(column, field.dataType) + for column, field in zip(table.itercolumns(), schema)], axis=1) + else: + return table.to_pandas() + else: + return table.to_pandas(date_as_object=True) def _get_local_timezone(): diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index bd137a1a02681..20db0522ccf5a 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -124,7 +124,7 @@ def returnType(self): elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: if isinstance(self._returnType_placeholder, StructType): try: - to_arrow_schema(self._returnType_placeholder) + to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid returnType with grouped map Pandas UDFs: " @@ -134,6 +134,9 @@ def returnType(self): "UDFs: returnType must be a StructType.") elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: try: + # StructType is not yet allowed as a return type, explicitly check here to fail fast + if isinstance(self._returnType_placeholder, StructType): + raise TypeError to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 01934a0e72758..478fdc081d352 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,8 +38,8 @@ from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type + BatchedSerializer, ArrowStreamPandasUDFSerializer +from pyspark.sql.types import to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): + pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series" raise TypeError("Return type of the user-defined function should be " - "Pandas.Series, but is {}".format(type(result))) + "{}, but is {}".format(pd_type, type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " "expected %d, got %d" % (len(a[0]), len(result))) @@ -100,10 +101,7 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): - assign_cols_by_name = runner_conf.get( - "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true") - assign_cols_by_name = assign_cols_by_name.lower() == "true" +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd @@ -122,15 +120,9 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. 
" "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + return result - # Assign result columns by schema name if user labeled with strings, else use position - if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns): - return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] - else: - return [(result[result.columns[i]], to_arrow_type(field.dataType)) - for i, field in enumerate(return_type)] - - return wrapped + return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] def wrap_grouped_agg_pandas_udf(f, return_type): @@ -224,7 +216,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -254,7 +246,16 @@ def read_udfs(pickleSer, infile, eval_type): timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion", "false").lower() == 'true' - ser = ArrowStreamPandasSerializer(timezone, safecheck) + # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType + assign_cols_by_name = runner_conf.get( + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\ + .lower() == "true" + + # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of + # pandas Series. See SPARK-27240. 
+ df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, + df_for_struct) else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 2bf6a58b55658..884dc8c6215ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -31,7 +31,6 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.network.util.JavaUtils -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index e9cff1a5a2007..ce755ffb7c9fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] @@ -145,7 +146,16 @@ case class FlatMapGroupsInPandasExec( sessionLocalTimeZone, pythonRunnerConf).compute(grouped, context.partitionId(), context) - columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) } } }
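# Illustration only (not part of the patch): end-to-end usage of the new StructType
# return support for scalar Pandas UDFs, mirroring the doctest and tests added above.
# Assumes an active SparkSession named `spark`.
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType, StringType, StructField, StructType

return_type = StructType([
    StructField('id', LongType()),
    StructField('str', StringType())])

@pandas_udf(return_type)
def make_struct(id):
    # A StructType result is returned as a pandas.DataFrame whose columns match the
    # schema -- by name here; assignment by position is also supported.
    return pd.DataFrame({'id': id, 'str': id.astype(str)})

# spark.range(3).select(make_struct(col('id')).alias('struct')).show()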