diff --git a/sdks/python/apache_beam/typehints/arrow_type_compatibility.py b/sdks/python/apache_beam/typehints/arrow_type_compatibility.py index cad6ac8751cac..c8e425f0e96a7 100644 --- a/sdks/python/apache_beam/typehints/arrow_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/arrow_type_compatibility.py @@ -303,13 +303,19 @@ def __init__(self, element_type: RowTypeConstraint): self._arrow_schema = arrow_schema @staticmethod - @BatchConverter.register def from_typehints(element_type, batch_type) -> Optional['PyarrowBatchConverter']: - if isinstance(element_type, RowTypeConstraint) and batch_type == pa.Table: - return PyarrowBatchConverter(element_type) + assert batch_type == pa.Table - return None + if not isinstance(element_type, RowTypeConstraint): + element_type = RowTypeConstraint.from_user_type(element_type) + if element_type is None: + raise TypeError( + "Element type must be compatible with Beam Schemas (" + "https://beam.apache.org/documentation/programming-guide/#schemas) " + "for batch type pa.Table.") + + return PyarrowBatchConverter(element_type) def produce_batch(self, elements): arrays = [ @@ -358,13 +364,11 @@ def __init__(self, element_type: type): self._arrow_type = _arrow_type_from_beam_fieldtype(beam_fieldtype) @staticmethod - @BatchConverter.register def from_typehints(element_type, batch_type) -> Optional['PyarrowArrayBatchConverter']: - if batch_type == pa.Array: - return PyarrowArrayBatchConverter(element_type) + assert batch_type == pa.Array - return None + return PyarrowArrayBatchConverter(element_type) def produce_batch(self, elements): return pa.array(list(elements), type=self._arrow_type) @@ -382,3 +386,16 @@ def get_length(self, batch: pa.Array): def estimate_byte_size(self, batch: pa.Array): return batch.nbytes + + +@BatchConverter.register(name="pyarrow") +def create_pyarrow_batch_converter( + element_type: type, batch_type: type) -> BatchConverter: + if batch_type == pa.Table: + return PyarrowBatchConverter.from_typehints( + element_type=element_type, batch_type=batch_type) + elif batch_type == pa.Array: + return PyarrowArrayBatchConverter.from_typehints( + element_type=element_type, batch_type=batch_type) + + raise TypeError("batch type must be pa.Table or pa.Array") diff --git a/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py b/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py index 6a8649cff1eaa..e708b151d9056 100644 --- a/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/arrow_type_compatibility_test.py @@ -19,6 +19,7 @@ import logging import unittest +from typing import Any from typing import Optional import pyarrow as pa @@ -192,6 +193,29 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class ArrowBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + pa.RecordBatch, + row_type.RowTypeConstraint.from_fields([ + ("bar", Optional[float]), # noqa: F821 + ("baz", Optional[str]), # noqa: F821 + ]), + r'batch type must be pa\.Table or pa\.Array', + ), + ( + pa.Table, + Any, + r'Element type must be compatible with Beam Schemas', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/typehints/batch.py b/sdks/python/apache_beam/typehints/batch.py index de6c7fb715727..73cfc2bfd08bb 100644 --- a/sdks/python/apache_beam/typehints/batch.py +++ b/sdks/python/apache_beam/typehints/batch.py @@ -30,6 +30,7 @@ from typing import Generic from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import Sequence from typing import TypeVar @@ -44,7 +45,8 @@ B = TypeVar('B') E = TypeVar('E') -BATCH_CONVERTER_REGISTRY: List[Callable[[type, type], 'BatchConverter']] = [] +BatchConverterConstructor = Callable[[type, type], 'BatchConverter'] +BATCH_CONVERTER_REGISTRY: Mapping[str, BatchConverterConstructor] = {} __all__ = ['BatchConverter'] @@ -72,26 +74,34 @@ def estimate_byte_size(self, batch): raise NotImplementedError @staticmethod - def register( - batch_converter_constructor: Callable[[type, type], 'BatchConverter']): - BATCH_CONVERTER_REGISTRY.append(batch_converter_constructor) - return batch_converter_constructor + def register(*, name: str): + def do_registration( + batch_converter_constructor: Callable[[type, type], 'BatchConverter']): + if name in BATCH_CONVERTER_REGISTRY: + raise AssertionError( + f"Attempted to register two batch converters with name {name}") + + BATCH_CONVERTER_REGISTRY[name] = batch_converter_constructor + return batch_converter_constructor + + return do_registration @staticmethod def from_typehints(*, element_type, batch_type) -> 'BatchConverter': element_type = typehints.normalize(element_type) batch_type = typehints.normalize(batch_type) - for constructor in BATCH_CONVERTER_REGISTRY: - result = constructor(element_type, batch_type) - if result is not None: - return result - - # TODO(https://github.com/apache/beam/issues/21654): Aggregate error - # information from the failed BatchConverter matches instead of this - # generic error. + errors = {} + for name, constructor in BATCH_CONVERTER_REGISTRY.items(): + try: + return constructor(element_type, batch_type) + except TypeError as e: + errors[name] = e.args[0] + + error_summaries = '\n\n'.join( + f"{name}:\n\t{msg}" for name, msg in errors.items()) raise TypeError( - f"Unable to find BatchConverter for element_type {element_type!r} and " - f"batch_type {batch_type!r}") + f"Unable to find BatchConverter for element_type={element_type!r} and " + f"batch_type={batch_type!r}. Error summaries:\n\n{error_summaries}") @property def batch_type(self): @@ -124,13 +134,13 @@ def __init__(self, batch_type, element_type): self.element_coder = coders.registry.get_coder(element_type) @staticmethod - @BatchConverter.register + @BatchConverter.register(name="list") def from_typehints(element_type, batch_type): - if (isinstance(batch_type, typehints.ListConstraint) and - batch_type.inner_type == element_type): - return ListBatchConverter(batch_type, element_type) - else: - return None + if (not isinstance(batch_type, typehints.ListConstraint) or + batch_type.inner_type != element_type): + raise TypeError("batch type must be List[T] for element type T") + + return ListBatchConverter(batch_type, element_type) def produce_batch(self, elements): return list(elements) @@ -173,29 +183,35 @@ def __init__( self.partition_dimension = partition_dimension @staticmethod - @BatchConverter.register + @BatchConverter.register(name="numpy") def from_typehints(element_type, batch_type) -> Optional['NumpyBatchConverter']: if not isinstance(element_type, NumpyTypeHint.NumpyTypeConstraint): try: element_type = NumpyArray[element_type, ()] - except TypeError: - # TODO: Is there a better way to detect if element_type is a dtype? - return None + except TypeError as e: + raise TypeError("Element type is not a dtype") from e if not isinstance(batch_type, NumpyTypeHint.NumpyTypeConstraint): if not batch_type == np.ndarray: - # TODO: Include explanation for mismatch? - return None + raise TypeError( + "batch type must be np.ndarray or " + "beam.typehints.batch.NumpyArray[..]") batch_type = NumpyArray[element_type.dtype, (N, )] if not batch_type.dtype == element_type.dtype: - return None - batch_shape = list(batch_type.shape) - partition_dimension = batch_shape.index(N) - batch_shape.pop(partition_dimension) - if not tuple(batch_shape) == element_type.shape: - return None + raise TypeError( + "batch type and element type must have equivalent dtypes " + f"(batch={batch_type.dtype}, element={element_type.dtype})") + + computed_element_shape = list(batch_type.shape) + partition_dimension = computed_element_shape.index(N) + computed_element_shape.pop(partition_dimension) + if not tuple(computed_element_shape) == element_type.shape: + raise TypeError( + "Failed to align batch type's batch dimension with element type. " + f"(batch type dimensions: {batch_type.shape}, element type " + f"dimenstions: {element_type.shape}") return NumpyBatchConverter( batch_type, diff --git a/sdks/python/apache_beam/typehints/batch_test.py b/sdks/python/apache_beam/typehints/batch_test.py index a6ea003dd496e..3fbad76fce06b 100644 --- a/sdks/python/apache_beam/typehints/batch_test.py +++ b/sdks/python/apache_beam/typehints/batch_test.py @@ -149,6 +149,38 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class BatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + typing.List[int], + str, + r'batch type must be List\[T\] for element type T', + ), + ( + np.ndarray, + typing.Any, + r'Element type is not a dtype', + ), + ( + np.array, + np.int64, + ( + r'batch type must be np\.ndarray or ' + r'beam\.typehints\.batch\.NumpyArray\[\.\.\]'), + ), + ( + NumpyArray[np.int64, (3, N, 2)], + NumpyArray[np.int64, (3, 7)], + r'Failed to align batch type\'s batch dimension', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + @contextlib.contextmanager def temp_seed(seed): state = random.getstate() diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py index a143f9c4ef379..ca9523f283490 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility.py @@ -136,7 +136,7 @@ def dtype_to_fieldtype(dtype): return Any -@BatchConverter.register +@BatchConverter.register(name="pandas") def create_pandas_batch_converter( element_type: type, batch_type: type) -> BatchConverter: if batch_type == pd.DataFrame: @@ -146,7 +146,7 @@ def create_pandas_batch_converter( return SeriesBatchConverter.from_typehints( element_type=element_type, batch_type=batch_type) - return None + raise TypeError("batch type must be pd.Series or pd.DataFrame") class DataFrameBatchConverter(BatchConverter): @@ -160,13 +160,15 @@ def __init__( @staticmethod def from_typehints(element_type, batch_type) -> Optional['DataFrameBatchConverter']: - if not batch_type == pd.DataFrame: - return None + assert batch_type == pd.DataFrame if not isinstance(element_type, RowTypeConstraint): element_type = RowTypeConstraint.from_user_type(element_type) if element_type is None: - return None + raise TypeError( + "Element type must be compatible with Beam Schemas (" + "https://beam.apache.org/documentation/programming-guide/#schemas) " + "for batch type pd.DataFrame") index_columns = [ field_name @@ -275,8 +277,7 @@ def unbatch(series): @staticmethod def from_typehints(element_type, batch_type) -> Optional['SeriesBatchConverter']: - if not batch_type == pd.Series: - return None + assert batch_type == pd.Series dtype = dtype_from_typehint(element_type) diff --git a/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py index 0ee9b1178a9ba..5a8dc72dd4b99 100644 --- a/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/pandas_type_compatibility_test.py @@ -18,6 +18,7 @@ """Unit tests for pandas batched type converters.""" import unittest +from typing import Any from typing import Optional import numpy as np @@ -115,7 +116,7 @@ dtype=pd.StringDtype()), }, ]) -class DataFrameBatchConverterTest(unittest.TestCase): +class PandasBatchConverterTest(unittest.TestCase): def create_batch_converter(self): return BatchConverter.from_typehints( element_type=self.element_typehint, batch_type=self.batch_typehint) @@ -208,5 +209,28 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class PandasBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + Any, + row_type.RowTypeConstraint.from_fields([ + ("bar", Optional[float]), # noqa: F821 + ("baz", Optional[str]), # noqa: F821 + ]), + r'batch type must be pd\.Series or pd\.DataFrame', + ), + ( + pd.DataFrame, + Any, + r'Element type must be compatible with Beam Schemas', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py index fbecb6d5105bd..f008174bcc03c 100644 --- a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py @@ -37,29 +37,31 @@ def __init__( self.partition_dimension = partition_dimension @staticmethod - @BatchConverter.register + @BatchConverter.register(name="pytorch") def from_typehints(element_type, batch_type) -> Optional['PytorchBatchConverter']: if not isinstance(element_type, PytorchTypeHint.PytorchTypeConstraint): - try: - element_type = PytorchTensor[element_type, ()] - except TypeError: - # TODO: Is there a better way to detect if element_type is a dtype? - return None + element_type = PytorchTensor[element_type, ()] if not isinstance(batch_type, PytorchTypeHint.PytorchTypeConstraint): if not batch_type == torch.Tensor: - # TODO: Include explanation for mismatch? - return None + raise TypeError( + "batch type must be torch.Tensor or " + "beam.typehints.pytorch_type_compatibility.PytorchTensor[..]") batch_type = PytorchTensor[element_type.dtype, (N, )] if not batch_type.dtype == element_type.dtype: - return None - batch_shape = list(batch_type.shape) - partition_dimension = batch_shape.index(N) - batch_shape.pop(partition_dimension) - if not tuple(batch_shape) == element_type.shape: - return None + raise TypeError( + "batch type and element type must have equivalent dtypes " + f"(batch={batch_type.dtype}, element={element_type.dtype})") + computed_element_shape = list(batch_type.shape) + partition_dimension = computed_element_shape.index(N) + computed_element_shape.pop(partition_dimension) + if not tuple(computed_element_shape) == element_type.shape: + raise TypeError( + "Could not align batch type's batch dimension with element type. " + f"(batch type dimensions: {batch_type.shape}, element type " + f"dimenstions: {element_type.shape}") return PytorchBatchConverter( batch_type, diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py index e851d4679ccb9..d1f5c0d271ee7 100644 --- a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py @@ -18,6 +18,7 @@ """Unit tests for pytorch_type_compabitility.""" import unittest +from typing import Any import pytest from parameterized import parameterized @@ -134,5 +135,32 @@ def test_hash(self): self.assertEqual(hash(self.create_batch_converter()), hash(self.converter)) +class PytorchBatchConverterErrorsTest(unittest.TestCase): + @parameterized.expand([ + ( + Any, + PytorchTensor[torch.int64, ()], + ( + r'batch type must be torch\.Tensor or ' + r'beam\.typehints\.pytorch_type_compatibility.PytorchTensor'), + ), + ( + PytorchTensor[torch.int64, (3, N, 2)], + PytorchTensor[torch.int64, (3, 7)], + r'Could not align batch type\'s batch dimension', + ), + ( + PytorchTensor[torch.int64, (N, 10)], + PytorchTensor[torch.float32, (10, )], + r'batch type and element type must have equivalent dtypes', + ), + ]) + def test_construction_errors( + self, batch_typehint, element_typehint, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + BatchConverter.from_typehints( + element_type=element_typehint, batch_type=batch_typehint) + + if __name__ == '__main__': unittest.main()