From 09238a57465e59fa31d0cb4e8d8a71b3b2f37b91 Mon Sep 17 00:00:00 2001
From: Brian Hulette
Date: Mon, 8 Aug 2022 16:32:51 -0700
Subject: [PATCH] Use pandas_type_compatibility BatchConverters for
 dataframe.schemas utilities

---
 sdks/python/apache_beam/dataframe/convert.py |   7 +-
 sdks/python/apache_beam/dataframe/schemas.py | 125 +++++-------
 .../apache_beam/dataframe/schemas_test.py    |  29 +++-
 .../typehints/pandas_type_compatibility.py   |  18 ++-
 4 files changed, 76 insertions(+), 103 deletions(-)

diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py
index 1a5da92e3b08b..cc92ae32c2986 100644
--- a/sdks/python/apache_beam/dataframe/convert.py
+++ b/sdks/python/apache_beam/dataframe/convert.py
@@ -29,8 +29,9 @@
 from apache_beam import pvalue
 from apache_beam.dataframe import expressions
 from apache_beam.dataframe import frame_base
-from apache_beam.dataframe import schemas
 from apache_beam.dataframe import transforms
+from apache_beam.dataframe.schemas import element_typehint_from_dataframe_proxy
+from apache_beam.dataframe.schemas import generate_proxy
 from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype
 
 if TYPE_CHECKING:
@@ -70,7 +71,7 @@ def to_dataframe(
   # Attempt to come up with a reasonable, stable label by retrieving
   # the name of these variables in the calling context.
   label = 'BatchElements(%s)' % _var_name(pcoll, 2)
-  proxy = schemas.generate_proxy(pcoll.element_type)
+  proxy = generate_proxy(pcoll.element_type)
 
   shim_dofn: beam.DoFn
   if isinstance(proxy, pd.DataFrame):
@@ -145,7 +146,7 @@ def process(self, element: pd.DataFrame) -> Iterable[pd.DataFrame]:
     yield element
 
   def infer_output_type(self, input_element_type):
-    return schemas.element_typehint_from_dataframe_proxy(
+    return element_typehint_from_dataframe_proxy(
         self._proxy, self._include_indexes)
 
 
diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py
index 97c4dae62e460..0c0b9d8bcef17 100644
--- a/sdks/python/apache_beam/dataframe/schemas.py
+++ b/sdks/python/apache_beam/dataframe/schemas.py
@@ -36,16 +36,13 @@
 import apache_beam as beam
 from apache_beam import typehints
-from apache_beam.portability.api import schema_pb2
 from apache_beam.transforms.util import BatchElements
-from apache_beam.typehints.native_type_compatibility import _match_is_optional
+from apache_beam.typehints.pandas_type_compatibility import create_pandas_batch_converter
 from apache_beam.typehints.pandas_type_compatibility import dtype_from_typehint
 from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype
 from apache_beam.typehints.row_type import RowTypeConstraint
 from apache_beam.typehints.schemas import named_fields_from_element_type
-from apache_beam.typehints.schemas import named_tuple_from_schema
-from apache_beam.typehints.schemas import named_tuple_to_schema
-from apache_beam.utils import proto_utils
+from apache_beam.typehints.typehints import normalize
 
 __all__ = (
     'BatchRowsAsDataFrame',
@@ -69,18 +66,21 @@ def __init__(self, *args, proxy=None, **kwargs):
     self._proxy = proxy
 
   def expand(self, pcoll):
-    proxy = generate_proxy(
-        pcoll.element_type) if self._proxy is None else self._proxy
-    if isinstance(proxy, pd.DataFrame):
-      columns = proxy.columns
-      construct = lambda batch: pd.DataFrame.from_records(
-          batch, columns=columns)
-    elif isinstance(proxy, pd.Series):
-      dtype = proxy.dtype
-      construct = lambda batch: pd.Series(batch, dtype=dtype)
+    if self._proxy is not None:
+      # Generate typehint
+      proxy = self._proxy
+      element_typehint = element_typehint_from_proxy(proxy)
     else:
-      raise NotImplementedError("Unknown proxy type: %s" % proxy)
-    return pcoll | self._batch_elements_transform | beam.Map(construct)
+      # Generate proxy
+      proxy = generate_proxy(pcoll.element_type)
+      element_typehint = pcoll.element_type
+
+    converter = create_pandas_batch_converter(
+        element_type=element_typehint, batch_type=type(proxy))
+
+    return (
+        pcoll | self._batch_elements_transform
+        | beam.Map(converter.produce_batch))
 
 
 def generate_proxy(element_type):
@@ -117,6 +117,20 @@ def element_type_from_dataframe(proxy, include_indexes=False):
   return element_typehint_from_dataframe_proxy(
       proxy, include_indexes).user_type
 
 
+def element_typehint_from_proxy(
+    proxy: pd.core.generic.NDFrame, include_indexes: bool = False):
+  if isinstance(proxy, pd.DataFrame):
+    return element_typehint_from_dataframe_proxy(
+        proxy, include_indexes=include_indexes)
+  elif isinstance(proxy, pd.Series):
+    if include_indexes:
+      import warnings
+      warnings.warn("include_indexes is a no-op for a Series proxy")
+    return dtype_to_fieldtype(proxy.dtype)
+  else:
+    raise TypeError("Unsupported proxy type: %s" % type(proxy))
+
+
 def element_typehint_from_dataframe_proxy(
     proxy: pd.DataFrame, include_indexes: bool = False) -> RowTypeConstraint:
@@ -168,82 +182,15 @@ def element_typehint_from_dataframe_proxy(
   return RowTypeConstraint.from_fields(fields, field_options=field_options)
 
 
-class _BaseDataframeUnbatchDoFn(beam.DoFn):
-  def __init__(self, namedtuple_ctor):
-    self._namedtuple_ctor = namedtuple_ctor
-
-  def _get_series(self, df):
-    raise NotImplementedError()
-
-  def process(self, df):
-    # TODO: Only do null checks for nullable types
-    def make_null_checking_generator(series):
-      nulls = pd.isnull(series)
-      return (None if isnull else value for isnull, value in zip(nulls, series))
-
-    all_series = self._get_series(df)
-    iterators = [
-        make_null_checking_generator(series) for series,
-        typehint in zip(all_series, self._namedtuple_ctor.__annotations__)
-    ]
-
-    # TODO: Avoid materializing the rows. Produce an object that references the
-    # underlying dataframe
-    for values in zip(*iterators):
-      yield self._namedtuple_ctor(*values)
-
-  def infer_output_type(self, input_type):
-    return self._namedtuple_ctor
-
-  @classmethod
-  def _from_serialized_schema(cls, schema_str):
-    return cls(
-        named_tuple_from_schema(
-            proto_utils.parse_Bytes(schema_str, schema_pb2.Schema)))
-
-  def __reduce__(self):
-    # when pickling, use bytes representation of the schema.
-    return (
-        self._from_serialized_schema,
-        (named_tuple_to_schema(self._namedtuple_ctor).SerializeToString(), ))
-
-
-class _UnbatchNoIndex(_BaseDataframeUnbatchDoFn):
-  def _get_series(self, df):
-    return [df[column] for column in df.columns]
-
-
-class _UnbatchWithIndex(_BaseDataframeUnbatchDoFn):
-  def _get_series(self, df):
-    return [df.index.get_level_values(i) for i in range(len(df.index.names))
-           ] + [df[column] for column in df.columns]
-
-
 def _unbatch_transform(proxy, include_indexes):
-  if isinstance(proxy, pd.DataFrame):
-    ctor = element_type_from_dataframe(proxy, include_indexes=include_indexes)
-
-    return beam.ParDo(
-        _UnbatchWithIndex(ctor) if include_indexes else _UnbatchNoIndex(ctor))
-  elif isinstance(proxy, pd.Series):
-    # Raise a TypeError if proxy has an unknown type
-    output_type = dtype_to_fieldtype(proxy.dtype)
-    # TODO: Should the index ever be included for a Series?
-    if _match_is_optional(output_type):
-
-      def unbatch(series):
-        for isnull, value in zip(pd.isnull(series), series):
-          yield None if isnull else value
-    else:
+  element_typehint = normalize(
+      element_typehint_from_proxy(proxy, include_indexes=include_indexes))
 
-      def unbatch(series):
-        yield from series
+  converter = create_pandas_batch_converter(
+      element_type=element_typehint, batch_type=type(proxy))
 
-    return beam.FlatMap(unbatch).with_output_types(output_type)
-  # TODO: What about scalar inputs?
-  else:
-    raise TypeError(
-        "Proxy '%s' has unsupported type '%s'" % (proxy, type(proxy)))
+  return beam.FlatMap(
+      converter.explode_batch).with_output_types(element_typehint)
 
 
 @typehints.with_input_types(Union[pd.DataFrame, pd.Series])
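
[Reviewer sketch, not part of the patch] Both directions now share the
BatchConverter machinery: expand() resolves a converter from the element
typehint and batches with produce_batch, while _unbatch_transform() resolves
the same converter from the proxy and explodes with explode_batch. A minimal
round trip, assuming the DirectRunner and an illustrative `Animal` row type:

    import typing

    import apache_beam as beam
    from apache_beam.dataframe.schemas import BatchRowsAsDataFrame
    from apache_beam.dataframe.schemas import UnbatchPandas
    from apache_beam.dataframe.schemas import generate_proxy

    class Animal(typing.NamedTuple):
      name: str
      speed: float

    with beam.Pipeline() as p:
      rows = p | beam.Create(
          [Animal('Falcon', 380.0), Animal('Parrot', 24.0)])
      # Batches elements into DataFrames via converter.produce_batch.
      dfs = rows | BatchRowsAsDataFrame()
      # Explodes each DataFrame back to Animal rows via converter.explode_batch;
      # the proxy determines the output element typehint.
      out = dfs | UnbatchPandas(generate_proxy(Animal))
      out | 'Print' >> beam.Map(print)
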
diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py
index f019f82ddb3cb..47639a26bf3e5 100644
--- a/sdks/python/apache_beam/dataframe/schemas_test.py
+++ b/sdks/python/apache_beam/dataframe/schemas_test.py
@@ -34,6 +34,7 @@
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.typehints import row_type
 from apache_beam.typehints import typehints
 from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 
@@ -52,10 +53,7 @@ def check_df_pcoll_equal(actual):
         drop=True)
     sorted_expected = expected.sort_values(
         by=list(expected.columns)).reset_index(drop=True)
-    if not sorted_actual.equals(sorted_expected):
-      raise AssertionError(
-          'Dataframes not equal: \n\nActual:\n%s\n\nExpected:\n%s' %
-          (sorted_actual, sorted_expected))
+    pd.testing.assert_frame_equal(sorted_actual, sorted_expected)
 
   return check_df_pcoll_equal
 
@@ -145,6 +143,8 @@ def test_simple_df(self):
         },
         columns=['name', 'id', 'height'])
 
+    expected.name = expected.name.astype(pd.StringDtype())
+
     with TestPipeline() as p:
       res = (
           p
@@ -160,6 +160,7 @@ def test_simple_df_with_beam_row(self):
             'height': list(float(i) for i in range(5))
         },
         columns=['name', 'id', 'height'])
+    expected.name = expected.name.astype(pd.StringDtype())
 
     with TestPipeline() as p:
       res = (
@@ -235,8 +236,14 @@ def test_batch_with_df_transform(self):
     assert_that(res, equal_to([('Falcon', 375.), ('Parrot', 25.)]))
 
   def assert_typehints_equal(self, left, right):
-    left = typehints.normalize(left)
-    right = typehints.normalize(right)
+    def maybe_drop_rowtypeconstraint(typehint):
+      if isinstance(typehint, row_type.RowTypeConstraint):
+        return typehint.user_type
+      else:
+        return typehint
+
+    left = maybe_drop_rowtypeconstraint(typehints.normalize(left))
+    right = maybe_drop_rowtypeconstraint(typehints.normalize(right))
 
     if match_is_named_tuple(left):
       self.assertTrue(match_is_named_tuple(right))
@@ -273,6 +280,16 @@ def test_unbatch_with_index(self, df_or_series, rows, _):
 
       assert_that(res, equal_to(rows))
 
+  @parameterized.expand(SERIES_TESTS, name_func=test_name_func)
+  def test_unbatch_series_with_index_warns(
+      self, series, unused_rows, unused_type):
+    proxy = series[:0]
+
+    with TestPipeline() as p:
+      input_pc = p | beam.Create([series[::2], series[1::2]])
+      with self.assertWarns(UserWarning):
+        _ = input_pc | schemas.UnbatchPandas(proxy, include_indexes=True)
+
   def test_unbatch_include_index_unnamed_index_raises(self):
     df = pd.DataFrame({'foo': [1, 2, 3, 4]})
     proxy = df[:0]
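
[Reviewer note, not part of the patch] The new StringDtype casts on `expected`
line up with the converter's dtype mapping: under the BatchConverter path, str
fields come back as pandas' nullable string dtype rather than the object dtype
that pd.DataFrame.from_records produced, and pd.testing.assert_frame_equal
compares dtypes strictly while printing a clearer diff on failure. A quick
check of the assumed mapping:

    import pandas as pd

    from apache_beam.typehints.pandas_type_compatibility import (
        dtype_from_typehint)

    # str typehints map to pandas' nullable string dtype, hence the casts.
    assert dtype_from_typehint(str) == pd.StringDtype()
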
diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py
index c730400b21b8d..6d5c658dcf168 100644
--- a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py
@@ -143,6 +143,19 @@ def dtype_to_fieldtype(dtype):
 __all__ = []
 
 
+@BatchConverter.register
+def create_pandas_batch_converter(
+    element_type: type, batch_type: type) -> BatchConverter:
+  if batch_type == pd.DataFrame:
+    return DataFrameBatchConverter.from_typehints(
+        element_type=element_type, batch_type=batch_type)
+  elif batch_type == pd.Series:
+    return SeriesBatchConverter.from_typehints(
+        element_type=element_type, batch_type=batch_type)
+
+  return None
+
+
 class DataFrameBatchConverter(BatchConverter):
   def __init__(
       self,
@@ -152,7 +165,6 @@ def __init__(
     self._columns = [name for name, _ in element_type._fields]
 
   @staticmethod
-  @BatchConverter.register
   def from_typehints(element_type,
                      batch_type) -> Optional['DataFrameBatchConverter']:
     if not batch_type == pd.DataFrame:
@@ -268,16 +280,12 @@ def unbatch(series):
     self.explode_batch = unbatch
 
   @staticmethod
-  @BatchConverter.register
  def from_typehints(element_type,
                      batch_type) -> Optional['SeriesBatchConverter']:
     if not batch_type == pd.Series:
       return None
 
     dtype = dtype_from_typehint(element_type)
-    if dtype == np.object:
-      # Don't create Any <-> Series[np.object] mapping
-      return None
 
     return SeriesBatchConverter(element_type, dtype)
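
[Reviewer sketch, not part of the patch] Registration moves from the two
from_typehints staticmethods to the single create_pandas_batch_converter
factory, so converters still resolve through the BatchConverter registry.
Dropping the np.object guard also means object-dtype Series converters (for
example, for Any) can now be constructed. A minimal resolution check, assuming
a stock Beam + pandas install:

    from typing import Any

    import pandas as pd

    # Importing this module applies the @BatchConverter.register hook above.
    from apache_beam.typehints import pandas_type_compatibility  # noqa: F401
    from apache_beam.typehints.batch import BatchConverter

    # Resolves to a SeriesBatchConverter via create_pandas_batch_converter.
    ints = BatchConverter.from_typehints(
        element_type=int, batch_type=pd.Series)
    series = ints.produce_batch([1, 2, 3])
    assert list(ints.explode_batch(series)) == [1, 2, 3]

    # Previously from_typehints returned None for object dtype; now this
    # resolves to a SeriesBatchConverter with dtype=object.
    anys = BatchConverter.from_typehints(
        element_type=Any, batch_type=pd.Series)
    assert anys.produce_batch(['a', 1]).dtype == object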