Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error reporting for BatchConverter match failure #24022

Merged
merged 6 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions sdks/python/apache_beam/typehints/arrow_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,19 @@ def __init__(self, element_type: RowTypeConstraint):
self._arrow_schema = arrow_schema

@staticmethod
@BatchConverter.register
def from_typehints(element_type,
batch_type) -> Optional['PyarrowBatchConverter']:
if isinstance(element_type, RowTypeConstraint) and batch_type == pa.Table:
return PyarrowBatchConverter(element_type)
assert batch_type == pa.Table

return None
if not isinstance(element_type, RowTypeConstraint):
element_type = RowTypeConstraint.from_user_type(element_type)
if element_type is None:
raise TypeError(
"Element type must be compatible with Beam Schemas ("
"https://beam.apache.org/documentation/programming-guide/#schemas) "
"for batch type pa.Table.")

return PyarrowBatchConverter(element_type)

def produce_batch(self, elements):
arrays = [
Expand Down Expand Up @@ -358,13 +364,11 @@ def __init__(self, element_type: type):
self._arrow_type = _arrow_type_from_beam_fieldtype(beam_fieldtype)

@staticmethod
@BatchConverter.register
def from_typehints(element_type,
batch_type) -> Optional['PyarrowArrayBatchConverter']:
if batch_type == pa.Array:
return PyarrowArrayBatchConverter(element_type)
assert batch_type == pa.Array

return None
return PyarrowArrayBatchConverter(element_type)

def produce_batch(self, elements):
return pa.array(list(elements), type=self._arrow_type)
Expand All @@ -382,3 +386,16 @@ def get_length(self, batch: pa.Array):

def estimate_byte_size(self, batch: pa.Array):
return batch.nbytes


@BatchConverter.register(name="pyarrow")
def create_pyarrow_batch_converter(
    element_type: type, batch_type: type) -> BatchConverter:
  """Construct the pyarrow ``BatchConverter`` matching ``batch_type``.

  Dispatches ``pa.Table`` batches to ``PyarrowBatchConverter`` and
  ``pa.Array`` batches to ``PyarrowArrayBatchConverter``.

  Raises:
    TypeError: if ``batch_type`` is neither ``pa.Table`` nor ``pa.Array``,
      so the registry can aggregate the failure into its error summary.
  """
  # At most one of these can match, so check order is irrelevant.
  if batch_type == pa.Array:
    return PyarrowArrayBatchConverter.from_typehints(
        element_type=element_type, batch_type=batch_type)
  if batch_type == pa.Table:
    return PyarrowBatchConverter.from_typehints(
        element_type=element_type, batch_type=batch_type)

  raise TypeError("batch type must be pa.Table or pa.Array")
24 changes: 24 additions & 0 deletions sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import logging
import unittest
from typing import Any
from typing import Optional

import pyarrow as pa
Expand Down Expand Up @@ -192,6 +193,29 @@ def test_hash(self):
self.assertEqual(hash(self.create_batch_converter()), hash(self.converter))


class ArrowBatchConverterErrorsTest(unittest.TestCase):
  """Checks the TypeError messages produced for unsupported pyarrow
  batch/element typehint combinations."""

  # Each case: (batch typehint, element typehint, regex the error must match).
  _ERROR_CASES = [
      (
          pa.RecordBatch,
          row_type.RowTypeConstraint.from_fields([
              ("bar", Optional[float]),  # noqa: F821
              ("baz", Optional[str]),  # noqa: F821
          ]),
          r'batch type must be pa\.Table or pa\.Array',
      ),
      (
          pa.Table,
          Any,
          r'Element type must be compatible with Beam Schemas',
      ),
  ]

  @parameterized.expand(_ERROR_CASES)
  def test_construction_errors(self, batch_type, element_type, expected_msg):
    with self.assertRaisesRegex(TypeError, expected_msg):
      BatchConverter.from_typehints(
          element_type=element_type, batch_type=batch_type)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
82 changes: 49 additions & 33 deletions sdks/python/apache_beam/typehints/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import Generic
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import TypeVar
Expand All @@ -44,7 +45,8 @@
B = TypeVar('B')
E = TypeVar('E')

BATCH_CONVERTER_REGISTRY: List[Callable[[type, type], 'BatchConverter']] = []
BatchConverterConstructor = Callable[[type, type], 'BatchConverter']
BATCH_CONVERTER_REGISTRY: Mapping[str, BatchConverterConstructor] = {}

__all__ = ['BatchConverter']

Expand Down Expand Up @@ -72,26 +74,34 @@ def estimate_byte_size(self, batch):
raise NotImplementedError

@staticmethod
def register(
batch_converter_constructor: Callable[[type, type], 'BatchConverter']):
BATCH_CONVERTER_REGISTRY.append(batch_converter_constructor)
return batch_converter_constructor
def register(*, name: str):
def do_registration(
batch_converter_constructor: Callable[[type, type], 'BatchConverter']):
if name in BATCH_CONVERTER_REGISTRY:
raise AssertionError(
f"Attempted to register two batch converters with name {name}")

BATCH_CONVERTER_REGISTRY[name] = batch_converter_constructor
return batch_converter_constructor

return do_registration

@staticmethod
def from_typehints(*, element_type, batch_type) -> 'BatchConverter':
element_type = typehints.normalize(element_type)
batch_type = typehints.normalize(batch_type)
for constructor in BATCH_CONVERTER_REGISTRY:
result = constructor(element_type, batch_type)
if result is not None:
return result

# TODO(https://github.com/apache/beam/issues/21654): Aggregate error
# information from the failed BatchConverter matches instead of this
# generic error.
errors = {}
for name, constructor in BATCH_CONVERTER_REGISTRY.items():
try:
return constructor(element_type, batch_type)
except TypeError as e:
errors[name] = e.args[0]

error_summaries = '\n\n'.join(
f"{name}:\n\t{msg}" for name, msg in errors.items())
raise TypeError(
f"Unable to find BatchConverter for element_type {element_type!r} and "
f"batch_type {batch_type!r}")
f"Unable to find BatchConverter for element_type={element_type!r} and "
f"batch_type={batch_type!r}. Error summaries:\n\n{error_summaries}")

@property
def batch_type(self):
Expand Down Expand Up @@ -124,13 +134,13 @@ def __init__(self, batch_type, element_type):
self.element_coder = coders.registry.get_coder(element_type)

@staticmethod
@BatchConverter.register
@BatchConverter.register(name="list")
def from_typehints(element_type, batch_type):
if (isinstance(batch_type, typehints.ListConstraint) and
batch_type.inner_type == element_type):
return ListBatchConverter(batch_type, element_type)
else:
return None
if (not isinstance(batch_type, typehints.ListConstraint) or
batch_type.inner_type != element_type):
raise TypeError("batch type must be List[T] for element type T")

return ListBatchConverter(batch_type, element_type)

def produce_batch(self, elements):
return list(elements)
Expand Down Expand Up @@ -173,29 +183,35 @@ def __init__(
self.partition_dimension = partition_dimension

@staticmethod
@BatchConverter.register
@BatchConverter.register(name="numpy")
def from_typehints(element_type,
batch_type) -> Optional['NumpyBatchConverter']:
if not isinstance(element_type, NumpyTypeHint.NumpyTypeConstraint):
try:
element_type = NumpyArray[element_type, ()]
except TypeError:
# TODO: Is there a better way to detect if element_type is a dtype?
return None
except TypeError as e:
raise TypeError("Element type is not a dtype") from e

if not isinstance(batch_type, NumpyTypeHint.NumpyTypeConstraint):
if not batch_type == np.ndarray:
# TODO: Include explanation for mismatch?
return None
raise TypeError(
"batch type must be np.ndarray or "
"beam.typehints.batch.NumpyArray[..]")
batch_type = NumpyArray[element_type.dtype, (N, )]

if not batch_type.dtype == element_type.dtype:
return None
batch_shape = list(batch_type.shape)
partition_dimension = batch_shape.index(N)
batch_shape.pop(partition_dimension)
if not tuple(batch_shape) == element_type.shape:
return None
raise TypeError(
"batch type and element type must have equivalent dtypes "
f"(batch={batch_type.dtype}, element={element_type.dtype})")

computed_element_shape = list(batch_type.shape)
partition_dimension = computed_element_shape.index(N)
computed_element_shape.pop(partition_dimension)
if not tuple(computed_element_shape) == element_type.shape:
raise TypeError(
"Failed to align batch type's batch dimension with element type. "
f"(batch type dimensions: {batch_type.shape}, element type "
f"dimenstions: {element_type.shape}")

return NumpyBatchConverter(
batch_type,
Expand Down
32 changes: 32 additions & 0 deletions sdks/python/apache_beam/typehints/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,38 @@ def test_hash(self):
self.assertEqual(hash(self.create_batch_converter()), hash(self.converter))


class BatchConverterErrorsTest(unittest.TestCase):
  """Checks the TypeError messages produced when no registered
  BatchConverter accepts the given batch/element typehints."""

  # Each case: (batch typehint, element typehint, regex the error must match).
  _ERROR_CASES = [
      (
          typing.List[int],
          str,
          r'batch type must be List\[T\] for element type T',
      ),
      (
          np.ndarray,
          typing.Any,
          r'Element type is not a dtype',
      ),
      (
          np.array,
          np.int64,
          (
              r'batch type must be np\.ndarray or '
              r'beam\.typehints\.batch\.NumpyArray\[\.\.\]'),
      ),
      (
          NumpyArray[np.int64, (3, N, 2)],
          NumpyArray[np.int64, (3, 7)],
          r'Failed to align batch type\'s batch dimension',
      ),
  ]

  @parameterized.expand(_ERROR_CASES)
  def test_construction_errors(self, batch_type, element_type, expected_msg):
    with self.assertRaisesRegex(TypeError, expected_msg):
      BatchConverter.from_typehints(
          element_type=element_type, batch_type=batch_type)


@contextlib.contextmanager
def temp_seed(seed):
state = random.getstate()
Expand Down
15 changes: 8 additions & 7 deletions sdks/python/apache_beam/typehints/pandas_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def dtype_to_fieldtype(dtype):
return Any


@BatchConverter.register
@BatchConverter.register(name="pandas")
def create_pandas_batch_converter(
element_type: type, batch_type: type) -> BatchConverter:
if batch_type == pd.DataFrame:
Expand All @@ -146,7 +146,7 @@ def create_pandas_batch_converter(
return SeriesBatchConverter.from_typehints(
element_type=element_type, batch_type=batch_type)

return None
raise TypeError("batch type must be pd.Series or pd.DataFrame")


class DataFrameBatchConverter(BatchConverter):
Expand All @@ -160,13 +160,15 @@ def __init__(
@staticmethod
def from_typehints(element_type,
batch_type) -> Optional['DataFrameBatchConverter']:
if not batch_type == pd.DataFrame:
return None
assert batch_type == pd.DataFrame

if not isinstance(element_type, RowTypeConstraint):
element_type = RowTypeConstraint.from_user_type(element_type)
if element_type is None:
return None
raise TypeError(
"Element type must be compatible with Beam Schemas ("
"https://beam.apache.org/documentation/programming-guide/#schemas) "
"for batch type pd.DataFrame")

index_columns = [
field_name
Expand Down Expand Up @@ -275,8 +277,7 @@ def unbatch(series):
@staticmethod
def from_typehints(element_type,
batch_type) -> Optional['SeriesBatchConverter']:
if not batch_type == pd.Series:
return None
assert batch_type == pd.Series

dtype = dtype_from_typehint(element_type)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Unit tests for pandas batched type converters."""

import unittest
from typing import Any
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -115,7 +116,7 @@
dtype=pd.StringDtype()),
},
])
class DataFrameBatchConverterTest(unittest.TestCase):
class PandasBatchConverterTest(unittest.TestCase):
def create_batch_converter(self):
return BatchConverter.from_typehints(
element_type=self.element_typehint, batch_type=self.batch_typehint)
Expand Down Expand Up @@ -208,5 +209,28 @@ def test_hash(self):
self.assertEqual(hash(self.create_batch_converter()), hash(self.converter))


class PandasBatchConverterErrorsTest(unittest.TestCase):
  """Checks the TypeError messages produced for unsupported pandas
  batch/element typehint combinations."""

  # Each case: (batch typehint, element typehint, regex the error must match).
  _ERROR_CASES = [
      (
          Any,
          row_type.RowTypeConstraint.from_fields([
              ("bar", Optional[float]),  # noqa: F821
              ("baz", Optional[str]),  # noqa: F821
          ]),
          r'batch type must be pd\.Series or pd\.DataFrame',
      ),
      (
          pd.DataFrame,
          Any,
          r'Element type must be compatible with Beam Schemas',
      ),
  ]

  @parameterized.expand(_ERROR_CASES)
  def test_construction_errors(self, batch_type, element_type, expected_msg):
    with self.assertRaisesRegex(TypeError, expected_msg):
      BatchConverter.from_typehints(
          element_type=element_type, batch_type=batch_type)


if __name__ == '__main__':
unittest.main()
Loading