From 934e11597886ea70eae0bfc4beeafbc1bcb1bd67 Mon Sep 17 00:00:00 2001
From: Joshua Hoskins
Date: Mon, 14 Sep 2020 14:18:22 -0700
Subject: [PATCH] common: gremlin support tweaks

add optional key parameter to Column, Table. rename Statistics -> Stat.
add fixtures, streaming, utils to common. Update mypy to 0.761 to fix
false errors. Fix flake8 complaints about extra line in __init__.py
files

Signed-off-by: Joshua Hoskins
---
 amundsen_common/__init__.py         |   1 -
 amundsen_common/fixtures.py         | 219 ++++++++++++++++
 amundsen_common/log/__init__.py     |   1 -
 amundsen_common/models/__init__.py  |   1 -
 amundsen_common/models/dashboard.py |   1 -
 amundsen_common/models/enums.py     |  22 ++
 amundsen_common/models/table.py     |  11 +-
 amundsen_common/models/user.py      |  10 +-
 amundsen_common/utils/__init__.py   |   0
 amundsen_common/utils/streams.py    | 390 ++++++++++++++++++++++++++++
 amundsen_common/utils/utils.py      |  12 +
 requirements.txt                    |   2 +-
 setup.cfg                           |   1 +
 tests/__init__.py                   |   1 -
 tests/unit/__init__.py              |   1 -
 tests/unit/log/__init__.py          |   1 -
 tests/unit/utils/__init__.py        |   2 +
 tests/unit/utils/test_streams.py    | 201 ++++++++++++++
 18 files changed, 860 insertions(+), 17 deletions(-)
 create mode 100644 amundsen_common/fixtures.py
 create mode 100644 amundsen_common/models/enums.py
 create mode 100644 amundsen_common/utils/__init__.py
 create mode 100644 amundsen_common/utils/streams.py
 create mode 100644 amundsen_common/utils/utils.py
 create mode 100644 tests/unit/utils/__init__.py
 create mode 100644 tests/unit/utils/test_streams.py

diff --git a/amundsen_common/__init__.py b/amundsen_common/__init__.py
index d66c0ef..f3145d7 100644
--- a/amundsen_common/__init__.py
+++ b/amundsen_common/__init__.py
@@ -1,3 +1,2 @@
 # Copyright Contributors to the Amundsen project.
 # SPDX-License-Identifier: Apache-2.0
-
diff --git a/amundsen_common/fixtures.py b/amundsen_common/fixtures.py
new file mode 100644
index 0000000..fcf6fb9
--- /dev/null
+++ b/amundsen_common/fixtures.py
@@ -0,0 +1,219 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import string
+from typing import Any, List, Optional
+
+from amundsen_common.models.table import (Column, ProgrammaticDescription,
+                                          Stat, Table, Tag, Application)
+from amundsen_common.models.user import User
+
+
+class Fixtures:
+    counter = 1000
+
+    @staticmethod
+    def next_int() -> int:
+        i = Fixtures.counter
+        Fixtures.counter += 1
+        return i
+
+    @staticmethod
+    def next_string(*, prefix: str = '', length: int = 10) -> str:
+        astr: str = prefix + \
+            ''.join(Fixtures.next_item(items=list(string.ascii_lowercase)) for _ in range(length)) + \
+            ('%06d' % Fixtures.next_int())
+        return astr
+
+    @staticmethod
+    def next_range() -> range:
+        return range(0, Fixtures.next_int() % 5)
+
+    @staticmethod
+    def next_item(*, items: List[Any]) -> Any:
+        return items[Fixtures.next_int() % len(items)]
+
+    @staticmethod
+    def next_database() -> str:
+        return Fixtures.next_item(items=list(["database1", "database2"]))
+
+    @staticmethod
+    def next_application(*, application_id: Optional[str] = None) -> Application:
+        if not application_id:
+            application_id = Fixtures.next_string(prefix='ap', length=8)
+        application = Application(application_url=f'https://{application_id}.example.com',
+                                  description=f'{application_id} description',
+                                  name=application_id.capitalize(),
+                                  id=application_id)
+        return application
+
+    @staticmethod
+    def next_tag(*, tag_name: Optional[str] = None) -> Tag:
+        if not tag_name:
+            tag_name = Fixtures.next_string(prefix='ta', length=8)
+        return Tag(tag_name=tag_name, tag_type='default')
+
+    @staticmethod
+    def next_tags() -> List[Tag]:
+        return sorted([Fixtures.next_tag() for _ in Fixtures.next_range()])
+
+    @staticmethod
+    def next_description_source() -> str:
+        return Fixtures.next_string(prefix='de', length=8)
+
+    @staticmethod
+    def next_description(*, text: Optional[str] = None, source: Optional[str] = None) -> ProgrammaticDescription:
+        if not text:
+            text = Fixtures.next_string(length=20)
+        if not source:
+            source = Fixtures.next_description_source()
+        return ProgrammaticDescription(text=text, source=source)
+
+    @staticmethod
+    def next_col_type() -> str:
+        return Fixtures.next_item(items=['varchar', 'int', 'blob', 'timestamp', 'datetime'])
+
+    @staticmethod
+    def next_column(*,
+                    table_key: str,
+                    sort_order: int,
+                    name: Optional[str] = None) -> Column:
+        if not name:
+            name = Fixtures.next_string(prefix='co', length=8)
+
+        return Column(name=name,
+                      description=f'{name} description',
+                      col_type=Fixtures.next_col_type(),
+                      key=f'{table_key}/{name}',
+                      sort_order=sort_order,
+                      stats=[Stat(stat_type='num_rows',
+                                  stat_val=f'{Fixtures.next_int() * 100}',
+                                  start_epoch=None,
+                                  end_epoch=None)])
+
+    @staticmethod
+    def next_columns(*,
+                     table_key: str,
+                     randomize_pii: bool = False,
+                     randomize_data_subject: bool = False) -> List[Column]:
+        return [Fixtures.next_column(table_key=table_key,
+                                     sort_order=i
+                                     ) for i in Fixtures.next_range()]
+
+    @staticmethod
+    def next_descriptions() -> List[ProgrammaticDescription]:
+        return sorted([Fixtures.next_description() for _ in Fixtures.next_range()])
+
+    @staticmethod
+    def next_table(table: Optional[str] = None,
+                   cluster: Optional[str] = None,
+                   schema: Optional[str] = None,
+                   database: Optional[str] = None,
+                   tags: Optional[List[Tag]] = None,
+                   application: Optional[Application] = None) -> Table:
+        """
+        Returns a table for testing in the test_database
+        """
+        if not database:
+            database = Fixtures.next_database()
+
+        if not table:
+            table = Fixtures.next_string(prefix='tb', length=8)
+
+        if not cluster:
+            cluster = Fixtures.next_string(prefix='cl', length=8)
+
+        if not schema:
+            schema = Fixtures.next_string(prefix='sc', length=8)
+
+        if not tags:
+            tags = Fixtures.next_tags()
+
+        table_key: str = f'{database}://{cluster}.{schema}/{table}'
+        # TODO: add owners, watermarks, last_updated_timestamp, source
+        return Table(database=database,
+                     cluster=cluster,
+                     schema=schema,
+                     name=table,
+                     key=table_key,
+                     tags=tags,
+                     table_writer=application,
+                     table_readers=[],
+                     description=f'{table} description',
+                     programmatic_descriptions=Fixtures.next_descriptions(),
+                     columns=Fixtures.next_columns(table_key=table_key),
+                     is_view=False
+                     )
+
+    @staticmethod
+    def next_user(*, user_id: Optional[str] = None, is_active: bool = True) -> User:
+        last_name = ''.join(Fixtures.next_item(items=list(string.ascii_lowercase)) for _ in range(6)).capitalize()
+        first_name = Fixtures.next_item(items=['alice', 'bob', 'carol', 'dan']).capitalize()
+        if not user_id:
+            user_id = Fixtures.next_string(prefix='us', length=8)
+        return User(user_id=user_id,
+                    email=f'{user_id}@example.com',
+                    is_active=is_active,
+                    first_name=first_name,
+                    last_name=last_name,
+                    full_name=f'{first_name} {last_name}')
+
+
+def next_int() -> int:
+    return Fixtures.next_int()
+
+
+def next_string(**kwargs: Any) -> str:
+    return Fixtures.next_string(**kwargs)
+
+
+def next_range() -> range:
+    return Fixtures.next_range()
+
+
+def next_item(**kwargs: Any) -> Any:
+    return Fixtures.next_item(**kwargs)
+
+
+def next_database() -> str:
+    return Fixtures.next_database()
+
+
+def next_tag(**kwargs: Any) -> Tag:
+    return Fixtures.next_tag(**kwargs)
+
+
+def next_tags() -> List[Tag]:
+    return Fixtures.next_tags()
+
+
+def next_description_source() -> str:
+    return Fixtures.next_description_source()
+
+
+def next_description(**kwargs: Any) -> ProgrammaticDescription:
+    return Fixtures.next_description(**kwargs)
+
+
+def next_col_type() -> str:
+    return Fixtures.next_col_type()
+
+
+def next_column(**kwargs: Any) -> Column:
+    return Fixtures.next_column(**kwargs)
+
+
+def next_columns(**kwargs: Any) -> List[Column]:
+    return Fixtures.next_columns(**kwargs)
+
+
+def next_descriptions() -> List[ProgrammaticDescription]:
+    return Fixtures.next_descriptions()
+
+
+def next_table(**kwargs: Any) -> Table:
+    return Fixtures.next_table(**kwargs)
+
+
+def next_user(**kwargs: Any) -> User:
+    return Fixtures.next_user(**kwargs)
diff --git a/amundsen_common/log/__init__.py b/amundsen_common/log/__init__.py
index d66c0ef..f3145d7 100644
--- a/amundsen_common/log/__init__.py
+++ b/amundsen_common/log/__init__.py
@@ -1,3 +1,2 @@
 # Copyright Contributors to the Amundsen project.
 # SPDX-License-Identifier: Apache-2.0
-
diff --git a/amundsen_common/models/__init__.py b/amundsen_common/models/__init__.py
index d66c0ef..f3145d7 100644
--- a/amundsen_common/models/__init__.py
+++ b/amundsen_common/models/__init__.py
@@ -1,3 +1,2 @@
 # Copyright Contributors to the Amundsen project.
 # SPDX-License-Identifier: Apache-2.0
-
diff --git a/amundsen_common/models/dashboard.py b/amundsen_common/models/dashboard.py
index c0d7681..a3c94c7 100644
--- a/amundsen_common/models/dashboard.py
+++ b/amundsen_common/models/dashboard.py
@@ -24,4 +24,3 @@ class DashboardSummarySchema(AttrsSchema):
     class Meta:
         target = DashboardSummary
         register_as_scheme = True
-
diff --git a/amundsen_common/models/enums.py b/amundsen_common/models/enums.py
new file mode 100644
index 0000000..421c045
--- /dev/null
+++ b/amundsen_common/models/enums.py
@@ -0,0 +1,22 @@
+from enum import Enum, auto
+from typing import Any
+
+
+# Enums that derive from AutoEnum will have automatically generated
+# values equal to their member names
+class AutoEnum(Enum):
+    @staticmethod
+    def _generate_next_value_(name: str, start: Any, count: int, last_values: Any) -> str:
+        return name
+
+
+class RoleType(AutoEnum):
+    READ_ONLY = auto()
+    READ_WRITE = auto()
+    ADMIN = auto()
+    OWNER = auto()
+
+
+class NodeLabel(Enum):
+    COLUMN = 1
+    TABLE = 2
diff --git a/amundsen_common/models/table.py b/amundsen_common/models/table.py
index 1493413..0637bc4 100644
--- a/amundsen_common/models/table.py
+++ b/amundsen_common/models/table.py
@@ -61,26 +61,27 @@ class Meta:
 
 
 @attr.s(auto_attribs=True, kw_only=True)
-class Statistics:
+class Stat:
     stat_type: str
     stat_val: Optional[str] = None
     start_epoch: Optional[int] = None
     end_epoch: Optional[int] = None
 
 
-class StatisticsSchema(AttrsSchema):
+class StatSchema(AttrsSchema):
     class Meta:
-        target = Statistics
+        target = Stat
         register_as_scheme = True
 
 
 @attr.s(auto_attribs=True, kw_only=True)
 class Column:
     name: str
+    key: Optional[str] = None
     description: Optional[str] = None
     col_type: str
     sort_order: int
-    stats: List[Statistics] = []
+    stats: List[Stat] = []
 
 
 class ColumnSchema(AttrsSchema):
@@ -115,6 +116,7 @@ class Meta:
         target = Source
         register_as_scheme = True
 
+
 @attr.s(auto_attribs=True, kw_only=True)
 class ResourceReport:
     name: str
@@ -151,6 +153,7 @@ class Table:
     cluster: str
     schema: str
     name: str
+    key: Optional[str] = None
     tags: List[Tag] = []
     badges: List[Badge] = []
     table_readers: List[Reader] = []
diff --git a/amundsen_common/models/user.py b/amundsen_common/models/user.py
index 55ff2a7..09b924d 100644
--- a/amundsen_common/models/user.py
+++ b/amundsen_common/models/user.py
@@ -1,7 +1,7 @@
 # Copyright Contributors to the Amundsen project.
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional, Dict
+from typing import Any, Optional, Dict
 
 import attr
 from marshmallow import ValidationError, validates_schema, pre_load
@@ -38,7 +38,7 @@ class User:
     manager_id: Optional[str] = None
     role_name: Optional[str] = None
     profile_url: Optional[str] = None
-    other_key_values: Optional[Dict[str, str]] = attr.ib(factory=dict)
+    other_key_values: Optional[Dict[str, str]] = attr.ib(factory=dict)  # type: ignore
     # TODO: Add frequent_used, bookmarked, & owned resources
 
 
@@ -57,14 +57,14 @@ def _str_no_value(self, s: Optional[str]) -> bool:
         return False
 
     @pre_load
-    def preprocess_data(self, data: Dict) -> Dict:
+    def preprocess_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
         if self._str_no_value(data.get('user_id')):
             data['user_id'] = data.get('email')
 
         if self._str_no_value(data.get('profile_url')):
             data['profile_url'] = ''
             if data.get('GET_PROFILE_URL'):
-                data['profile_url'] = data.get('GET_PROFILE_URL')(data['user_id'])
+                data['profile_url'] = data.get('GET_PROFILE_URL')(data['user_id'])  # type: ignore
 
         first_name = data.get('first_name')
         last_name = data.get('last_name')
@@ -81,7 +81,7 @@ def preprocess_data(self, data: Dict) -> Dict:
         return data
 
     @validates_schema
-    def validate_user(self, data: Dict) -> None:
+    def validate_user(self, data: Dict[str, Any]) -> None:
         if self._str_no_value(data.get('display_name')):
             raise ValidationError('"display_name", "full_name", or "email" must be provided')
 
diff --git a/amundsen_common/utils/__init__.py b/amundsen_common/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/amundsen_common/utils/streams.py b/amundsen_common/utils/streams.py
new file mode 100644
index 0000000..1b42fd1
--- /dev/null
+++ b/amundsen_common/utils/streams.py
@@ -0,0 +1,390 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import threading
+from typing import (
+    Any, AsyncIterator, Callable, Collection, Iterable, Iterator, List,
+    Optional, Tuple, TypeVar, Union
+)
+
+from typing_extensions import Final, final
+
+LOGGER = logging.getLogger(__name__)
+
+
+V = TypeVar('V')
+R = TypeVar('R')
+
+
+def one(ignored: Any) -> int:
+    return 1
+
+
+class PeekingIterator(Iterator[V]):
+    """
+    Like Iterator, but with peek(), peek_default(), and take_peeked()
+    """
+    def __init__(self, iterable: Iterable[V]):
+        self.it: Final[Iterator[V]] = iterable if isinstance(iterable, Iterator) else iter(iterable)
+        self.has_peeked_value = False
+        self.peeked_value: Optional[V] = None
+        # RLock could make sense, but it would be just weird for the same thread to try to peek from the same
+        # blocking iterator
+        self.lock: Final[threading.Lock] = threading.Lock()
+
+    @final
+    # @overrides Iterator, but @overrides doesn't like it
+    def __next__(self) -> V:
+        """
+        :return: the previously peeked value or the next
+        :raises StopIteration if there are no more values
+        """
+        with self.lock:
+            value: V
+            if self.has_peeked_value:
+                value = self.peeked_value  # type: ignore
+                self.peeked_value = None
+                self.has_peeked_value = False
+            else:
+                value = next(self.it)
+            assert not self.has_peeked_value
+            return value
+
+    @final
+    def peek(self) -> V:
+        """
+        :return: the previously peeked value or the next
+        :raises StopIteration if there are no more values
+        """
+        with self.lock:
+            if not self.has_peeked_value:
+                self.peeked_value = next(self.it)
+                self.has_peeked_value = True
+            assert self.has_peeked_value
+            return self.peeked_value  # type: ignore
+
+    @final
+    def peek_default(self, default: Optional[V]) -> Optional[V]:
+        """
+        :return: the previously peeked value or the next, or default if there are no more values
+        """
+        try:
+            return self.peek()
+        except StopIteration:
+            return default
+
+    @final
+    def take_peeked(self, value: V) -> None:
+        with self.lock:
+            assert self.has_peeked_value, 'expected to find a peeked value'
+            assert self.peeked_value is value, 'expected the peeked value to be the same'
+            self.peeked_value = None
+            self.has_peeked_value = False
+
+    @final
+    def has_more(self) -> bool:
+        try:
+            self.peek()
+            return True
+        except StopIteration:
+            return False
+
+
+class PeekingAsyncIterator(AsyncIterator[V]):
+    """
+    Like AsyncIterator, but with peek(), peek_default(), and take_peeked()
+    """
+    def __init__(self, iterable: AsyncIterator[V]):
+        self.it: Final[AsyncIterator[V]] = iterable
+        self.has_peeked_value = False
+        self.peeked_value: Optional[V] = None
+        # RLock could make sense, but it would be just weird for the same thread to try to peek from the same
+        # blocking iterator
+        self.lock: Final[threading.Lock] = threading.Lock()
+
+    @final
+    # @overrides AsyncIterator, but @overrides doesn't like it
+    async def __anext__(self) -> V:
+        """
+        :return: the previously peeked value or the next
+        :raises StopAsyncIteration if there are no more values
+        """
+        with self.lock:
+            value: V
+            if self.has_peeked_value:
+                value = self.peeked_value  # type: ignore
+                self.peeked_value = None
+                self.has_peeked_value = False
+            else:
+                value = await self.it.__anext__()
+            assert not self.has_peeked_value
+            return value
+
+    @final
+    async def peek(self) -> V:
+        """
+        :return: the previously peeked value or the next
+        :raises StopAsyncIteration if there are no more values
+        """
+        with self.lock:
+            if not self.has_peeked_value:
+                self.peeked_value = await self.it.__anext__()
+                self.has_peeked_value = True
+            assert self.has_peeked_value
+            return self.peeked_value  # type: ignore
+
+    @final
+    async def peek_default(self, default: Optional[V]) -> Optional[V]:
+        """
+        :return: the previously peeked value or the next, or default if there are no more values
+        """
+        try:
+            return await self.peek()
+        except StopAsyncIteration:
+            return default
+
+    @final
+    def take_peeked(self, value: V) -> None:
+        with self.lock:
+            assert self.has_peeked_value, 'expected to find a peeked value'
+            assert self.peeked_value is value, 'expected the peeked value to be the same'
+            self.peeked_value = None
+            self.has_peeked_value = False
+
+    @final
+    async def has_more(self) -> bool:
+        try:
+            await self.peek()
+            return True
+        except StopAsyncIteration:
+            return False
+
+
+def one_chunk(*, it: PeekingIterator[V], n: int, metric: Callable[[V], int]) -> Tuple[Iterable[V], bool]:
+    """
+    :param it: stream of values as a PeekingIterator (or regular iterable if you are only going to take the first chunk
+        and don't care about the peeked value being consumed)
+    :param n: consume stream until the accumulated metric reaches n
+    :param metric: the callable that returns a positive metric for a value
+    :returns the chunk and whether there are more items
+    """
+    items: List[V] = []
+    items_metric: int = 0
+    try:
+        while True:
+            item = it.peek()
+            item_metric = metric(item)
+            # negative would be insane, let's say positive
+            assert item_metric > 0, \
+                f'expected metric to be positive! item_metric={item_metric}, metric={metric}, item={item}'
+            if not items and item_metric > n:
+                # should we assert instead? it's probably a surprise to the caller too, and might fail for whatever
+                # limit they were trying to avoid, but let's give them a shot at least.
+                LOGGER.error(f"expected a single item's metric to be less than the chunk limit! {item_metric} > {n}, "
+                             f"but returning to make progress")
+                items.append(item)
+                it.take_peeked(item)
+                items_metric += item_metric
+                break
+            elif items_metric + item_metric <= n:
+                items.append(item)
+                it.take_peeked(item)
+                items_metric += item_metric
+                if items_metric >= n:
+                    # we're full
+                    break
+                # else keep accumulating
+            else:
+                assert items_metric + item_metric > n
+                # we're full
+                break
+    # don't catch other exceptions, let those be a concern for callers
+    except StopIteration:
+        pass
+
+    has_more = it.has_more()
+    return tuple(items), has_more
+
+
+def chunk(it: Union[Iterable[V], PeekingIterator[V]], n: int, metric: Callable[[V], int] = one
+          ) -> Iterable[Iterable[V]]:
+    """
+    :param it: stream of values as a PeekingIterator (or regular iterable if you are only going to take the first chunk
+        and don't care about the peeked value being consumed)
+    :param n: consume stream until the accumulated metric reaches n (per chunk)
+    :param metric: the callable that returns a positive metric for a value
+    :returns the Iterable (generator) of chunks
+    """
+    if not isinstance(it, PeekingIterator):
+        it = PeekingIterator(it)
+    assert isinstance(it, PeekingIterator)
+    has_more: bool = True
+    while has_more:
+        items, has_more = one_chunk(it=it, n=n, metric=metric)
+        if items or has_more:
+            yield items
+
+
+async def async_one_chunk(
+        it: PeekingAsyncIterator[V], n: int, metric: Callable[[V], int] = one) -> Tuple[Iterable[V], bool]:
+    """
+    :param it: stream of values as a PeekingAsyncIterator
+    :param n: consume stream until the accumulated metric reaches n
+    :param metric: the callable that returns a positive metric for a value
+    :returns the chunk and whether there are more items
+    """
+    items: List[V] = []
+    items_metric: int = 0
+    if not isinstance(it, PeekingAsyncIterator):
+        it = PeekingAsyncIterator(it)
+    assert isinstance(it, PeekingAsyncIterator)
+    try:
+        while True:
+            item = await it.peek()
+            item_metric = metric(item)
+            # negative would be insane, let's say positive
+            assert item_metric > 0, \
+                f'expected metric to be positive! item_metric={item_metric}, metric={metric}, item={item}'
+            if not items and item_metric > n:
+                # should we assert instead? it's probably a surprise to the caller too, and might fail for whatever
+                # limit they were trying to avoid, but let's give them a shot at least.
+                LOGGER.error(f"expected a single item's metric to be less than the chunk limit! {item_metric} > {n}, "
+                             f"but returning to make progress")
+                items.append(item)
+                it.take_peeked(item)
+                items_metric += item_metric
+                break
+            elif items_metric + item_metric <= n:
+                items.append(item)
+                it.take_peeked(item)
+                items_metric += item_metric
+                if items_metric >= n:
+                    # we're full
+                    break
+                # else keep accumulating
+            else:
+                assert items_metric + item_metric > n
+                # we're full
+                break
+    # don't catch other exceptions, let those be a concern for callers
+    except StopAsyncIteration:
+        pass
+
+    has_more = await it.has_more()
+    return tuple(items), has_more
+
+
+async def async_chunk(*, it: Union[AsyncIterator[V], PeekingAsyncIterator[V]], n: int, metric: Callable[[V], int]
+                      ) -> AsyncIterator[Iterable[V]]:
+    """
+    :param it: stream of values as a PeekingAsyncIterator
+    :param n: consume stream until the accumulated metric reaches n (per chunk)
+    :param metric: the callable that returns a positive metric for a value
+    :returns the AsyncIterator (async generator) of chunks
+    """
+    if not isinstance(it, PeekingAsyncIterator):
+        it = PeekingAsyncIterator(it)
+    assert isinstance(it, PeekingAsyncIterator)
+    has_more: bool = True
+    while has_more:
+        items, has_more = await async_one_chunk(it=it, n=n, metric=metric)
+        if items or has_more:
+            yield items
+
+
+def reduce_in_chunks(*, stream: Iterable[V], n: int, initial: R,
+                     consumer: Callable[[Iterable[V], R], R], metric: Callable[[V], int] = one) -> R:
+    """
+    :param stream: stream of values
+    :param n: consume stream until n is reached. if n is 0 (or negative), process whole stream as one chunk.
+    :param metric: the callable that returns a positive metric for a value
+    :param initial: the initial state
+    :param consumer: the callable to handle the chunk
+    :returns the final state
+    """
+    if n > 0:
+        it = PeekingIterator(stream)
+        state = initial
+        for items in chunk(it=it, n=n, metric=metric):
+            state = consumer(items, state)
+        return state
+    else:
+        return consumer(stream, initial)
+
+
+async def async_reduce_in_chunks(*, stream: AsyncIterator[V], n: int, metric: Callable[[V], int], initial: R,
+                                 consumer: Callable[[Iterable[V], R], R]) -> R:
+    """
+    :param stream: stream of values
+    :param n: if n is 0 (or negative), process whole stream as one chunk
+    :param metric: the callable that returns a positive metric for a value
+    :param initial: the initial state
+    :param consumer: the callable to handle the chunk
+    :returns the final state
+    """
+    if n > 0:
+        it = PeekingAsyncIterator(stream)
+        state = initial
+        async for items in async_chunk(it=it, n=n, metric=metric):
+            state = consumer(items, state)
+        return state
+    else:
+        return consumer(tuple([_ async for _ in stream]), initial)
+
+
+def consume_in_chunks(*, stream: Iterable[V], n: int, consumer: Callable[[Iterable[V]], None],
+                      metric: Callable[[V], int] = one) -> int:
+    """
+    :param stream: stream of values
+    :param n: consume stream until n is reached. if n is 0 (or negative), process whole stream as one chunk
+    :param metric: the callable that returns a positive metric for a value
+    :param consumer: the callable to handle the chunk
+    :return: the number of items consumed
+    """
+    _actual_state: int = 0
+
+    def _consumer(things: Iterable[V], ignored: None) -> None:
+        nonlocal _actual_state
+        things = _assure_collection(things)
+        assert isinstance(things, Collection)  # appease the types
+        _actual_state += len(things)
+        consumer(things)
+    reduce_in_chunks(stream=stream, n=n, initial=None, consumer=_consumer, metric=metric)
+    return _actual_state
+
+
+async def async_consume_in_chunks(*, stream: AsyncIterator[V], n: int, consumer: Callable[[Iterable[V]], None],
+                                  metric: Callable[[V], int] = one) -> int:
+    _actual_state: int = 0
+
+    def _consumer(things: Iterable[V], ignored: None) -> None:
+        nonlocal _actual_state
+        things = _assure_collection(things)
+        assert isinstance(things, Collection)  # appease the types
+        _actual_state += len(things)
+        consumer(things)
+    await async_reduce_in_chunks(stream=stream, n=n, initial=None, consumer=_consumer, metric=metric)
+    return _actual_state
+
+
+def consume_in_chunks_with_state(*, stream: Iterable[V], n: int, consumer: Callable[[Iterable[V]], None],
+                                 state: Callable[[V], R], metric: Callable[[V], int] = one) -> Iterable[R]:
+    _actual_state: List[R] = list()
+
+    def _consumer(things: Iterable[V], ignored: None) -> None:
+        nonlocal _actual_state
+        things = _assure_collection(things)
+        assert isinstance(things, Collection)  # appease the types
+        _actual_state.extend(map(state, things))
+        consumer(things)
+
+    reduce_in_chunks(stream=stream, n=n, initial=None, consumer=_consumer, metric=metric)
+    return tuple(_actual_state)
+
+
+def _assure_collection(iterable: Iterable[V]) -> Collection[V]:
+    if isinstance(iterable, Collection):
+        return iterable
+    else:
+        return tuple(iterable)
diff --git a/amundsen_common/utils/utils.py b/amundsen_common/utils/utils.py
new file mode 100644
index 0000000..b4780e9
--- /dev/null
+++ b/amundsen_common/utils/utils.py
@@ -0,0 +1,12 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, TypeVar
+
+X = TypeVar('X')
+
+
+def check_not_none(x: Optional[X], *, message: str = 'is None') -> X:
+    if x is None:
+        raise RuntimeError(message)
+    return x
diff --git a/requirements.txt b/requirements.txt
index e480d4f..5312baf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ flake8==3.7.8
 Flask==1.1.1
 marshmallow==2.15.3
 marshmallow-annotations==2.4.0
-mypy==0.720
+mypy==0.761
 pytest>=4.6
 pytest-cov
 pytest-mock
diff --git a/setup.cfg b/setup.cfg
index f4db814..1f621cd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,6 +35,7 @@ check_untyped_defs = true
 disallow_any_generics = true
 disallow_incomplete_defs = true
 disallow_untyped_defs = true
+ignore_missing_imports = true
 no_implicit_optional = true
 
 [mypy-tests.*]
diff --git a/tests/__init__.py b/tests/__init__.py
index d66c0ef..f3145d7 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,3 +1,2 @@
 # Copyright Contributors to the Amundsen project.
 # SPDX-License-Identifier: Apache-2.0
-
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
index d66c0ef..f3145d7 100644
--- a/tests/unit/__init__.py
+++ b/tests/unit/__init__.py
@@ -1,3 +1,2 @@
 # Copyright Contributors to the Amundsen project.
 # SPDX-License-Identifier: Apache-2.0
-
diff --git a/tests/unit/log/__init__.py b/tests/unit/log/__init__.py
index d66c0ef..f3145d7 100644
--- a/tests/unit/log/__init__.py
+++ b/tests/unit/log/__init__.py
@@ -1,3 +1,2 @@
 # Copyright Contributors to the Amundsen project.
 # SPDX-License-Identifier: Apache-2.0
-
diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py
new file mode 100644
index 0000000..f3145d7
--- /dev/null
+++ b/tests/unit/utils/__init__.py
@@ -0,0 +1,2 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tests/unit/utils/test_streams.py b/tests/unit/utils/test_streams.py
new file mode 100644
index 0000000..fe73fa6
--- /dev/null
+++ b/tests/unit/utils/test_streams.py
@@ -0,0 +1,201 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import logging
+import unittest
+from typing import AsyncIterator, Iterable
+from unittest.mock import Mock, call
+
+from amundsen_common.utils.streams import (
+    PeekingIterator, _assure_collection, async_consume_in_chunks,
+    consume_in_chunks, consume_in_chunks_with_state, one_chunk,
+    reduce_in_chunks
+)
+
+
+class TestConsumer(unittest.TestCase):
+    def test_consume_in_chunks(self) -> None:
+        values = Mock()
+        values.side_effect = list(range(5))
+        consumer = Mock()
+        parent = Mock()
+        parent.values = values
+        parent.consumer = consumer
+
+        def stream() -> Iterable[int]:
+            for _ in range(5):
+                yield values()
+
+        count = consume_in_chunks(stream=stream(), n=2, consumer=consumer)
+        self.assertEqual(count, 5)
+        # this might look a little weird, but PeekingIterator is why
+        self.assertSequenceEqual([call.values(), call.values(), call.values(), call.consumer((0, 1)),
+                                  call.values(), call.values(), call.consumer((2, 3)), call.consumer((4,))],
+                                 parent.mock_calls)
+
+    def test_consume_in_chunks_with_exception(self) -> None:
+        consumer = Mock()
+
+        def stream() -> Iterable[int]:
+            yield from range(10)
+            raise KeyError('hi')
+
+        with self.assertRaisesRegex(KeyError, 'hi'):
+            consume_in_chunks(stream=stream(), n=4, consumer=consumer)
+        self.assertSequenceEqual([call((0, 1, 2, 3)), call((4, 5, 6, 7))], consumer.mock_calls)
+
+    def test_consume_in_chunks_with_state(self) -> None:
+        values = Mock()
+        values.side_effect = list(range(5))
+        consumer = Mock()
+        consumer.side_effect = list(range(1, 4))
+        state = Mock()
+        state.side_effect = lambda x: x * 10
+        parent = Mock()
+        parent.values = values
+        parent.consumer = consumer
+        parent.state = state
+
+        def stream() -> Iterable[int]:
+            for _ in range(5):
+                yield values()
+
+        result = consume_in_chunks_with_state(stream=stream(), n=2, consumer=consumer, state=state)
+        self.assertSequenceEqual(tuple(result), (0, 10, 20, 30, 40))
+        # this might look a little weird, but PeekingIterator is why
+        self.assertSequenceEqual([call.values(), call.values(), call.values(), call.state(0), call.state(1),
+                                  call.consumer((0, 1)), call.values(), call.values(), call.state(2), call.state(3),
+                                  call.consumer((2, 3)), call.state(4), call.consumer((4,))],
+                                 parent.mock_calls)
+
+    def test_consume_in_chunks_no_batch(self) -> None:
+        consumer = Mock()
+        count = consume_in_chunks(stream=range(100000000), n=-1, consumer=consumer)
+        self.assertEqual(100000000, count)
+        consumer.assert_called_once()
+
+    def test_reduce_in_chunks(self) -> None:
+        values = Mock()
+        values.side_effect = list(range(5))
+        consumer = Mock()
+        consumer.side_effect = list(range(1, 4))
+        parent = Mock()
+        parent.values = values
+        parent.consumer = consumer
+
+        def stream() -> Iterable[int]:
+            for _ in range(5):
+                yield values()
+
+        result = reduce_in_chunks(stream=stream(), n=2, initial=0, consumer=consumer)
+        self.assertEqual(result, 3)
+        # this might look a little weird, but PeekingIterator is why
+        self.assertSequenceEqual([call.values(), call.values(), call.values(), call.consumer((0, 1), 0),
+                                  call.values(), call.values(), call.consumer((2, 3), 1), call.consumer((4,), 2)],
+                                 parent.mock_calls)
+
+    def test_async_consume_in_chunks(self) -> None:
+        values = Mock()
+        values.side_effect = list(range(5))
+        consumer = Mock()
+        parent = Mock()
+        parent.values = values
+        parent.consumer = consumer
+
+        async def stream() -> AsyncIterator[int]:
+            for _ in range(5):
+                yield values()
+
+        count = asyncio.run(async_consume_in_chunks(stream=stream(), n=2, consumer=consumer))
+        self.assertEqual(5, count, 'count')
+        # this might look a little weird, but PeekingIterator is why
+        self.assertSequenceEqual([call.values(), call.values(), call.values(), call.consumer((0, 1)),
+                                  call.values(), call.values(), call.consumer((2, 3)), call.consumer((4,))],
+                                 parent.mock_calls)
+
+    def test_one_chunk_logging(self) -> None:
+        it = PeekingIterator(range(1, 4))
+        actual, has_more = one_chunk(it=it, n=2, metric=lambda x: x)
+        self.assertSequenceEqual([1], tuple(actual))
+        self.assertTrue(has_more)
+
+        actual, has_more = one_chunk(it=it, n=2, metric=lambda x: x)
+        self.assertSequenceEqual([2], tuple(actual))
+        self.assertTrue(has_more)
+
+        with self.assertLogs(logger='amundsen_common.utils.streams', level=logging.ERROR):
+            actual, has_more = one_chunk(it=it, n=2, metric=lambda x: x)
+        self.assertSequenceEqual([3], tuple(actual))
+        self.assertFalse(has_more)
+
+    def test_assure_collection(self) -> None:
+        actual = _assure_collection(iter(range(2)))
+        self.assertIsInstance(actual, tuple)
+        self.assertEqual((0, 1), actual)
+        actual = _assure_collection(list(range(2)))
+        self.assertIsInstance(actual, list)
+        self.assertEqual([0, 1], actual)
+        actual = _assure_collection(set(range(2)))
+        self.assertIsInstance(actual, set)
+        self.assertEqual({0, 1}, actual)
+        actual = _assure_collection(frozenset(range(2)))
+        self.assertIsInstance(actual, frozenset)
+        self.assertEqual(frozenset({0, 1}), actual)
+
+
+class TestPeekingIterator(unittest.TestCase):
+    # TODO: it'd be good to test the locking
+    def test_no_peek(self) -> None:
+        it = PeekingIterator(range(3))
+        self.assertEqual(0, next(it))
+        self.assertEqual(1, next(it))
+        self.assertEqual(2, next(it))
+        with self.assertRaises(StopIteration):
+            next(it)
+
+    def test_peek_is_next(self) -> None:
+        it = PeekingIterator(range(2))
+        self.assertEqual(0, it.peek())
+        self.assertTrue(it.has_more())
+        self.assertEqual(0, next(it))
+        self.assertTrue(it.has_more())
+        self.assertEqual(1, next(it))
+        self.assertFalse(it.has_more())
+        with self.assertRaises(StopIteration):
+            next(it)
+
+    def test_peek_repeats(self) -> None:
+        it = PeekingIterator(range(2))
+        for _ in range(100):
+            self.assertEqual(0, it.peek())
+        self.assertEqual(0, next(it))
+        self.assertEqual(1, next(it))
+
+    def test_peek_after_exhaustion(self) -> None:
+        it = PeekingIterator(range(2))
+        self.assertEqual(0, next(it))
+        self.assertEqual(1, next(it))
+        with self.assertRaises(StopIteration):
+            next(it)
+        with self.assertRaises(StopIteration):
+            it.peek()
+        self.assertEqual(999, it.peek_default(999))
+
+    def test_take_peeked(self) -> None:
+        it = PeekingIterator(range(2))
+        self.assertEqual(0, it.peek())
+        it.take_peeked(0)
+        self.assertEqual(1, next(it))
+        with self.assertRaises(StopIteration):
+            next(it)
+
+    def test_take_peeked_wrong_value(self) -> None:
+        it = PeekingIterator(range(2))
+        self.assertEqual(0, it.peek())
+        with self.assertRaisesRegex(AssertionError, 'expected the peeked value to be the same'):
+            it.take_peeked(1)
+        it.take_peeked(0)
+        self.assertEqual(1, next(it))
+
+# TODO: test PeekingAsyncIterator directly
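
Reviewer note, not part of the patch: a minimal usage sketch of the new chunking helpers in
amundsen_common/utils/streams.py, using the names and signatures added above. The rows list and
the n=32 character budget are invented for illustration.

from amundsen_common.utils.streams import consume_in_chunks, reduce_in_chunks

rows = [f'row-{i}' for i in range(10)]

# count-based batching: at most 4 items per chunk (the default metric counts 1 per item)
total_items = consume_in_chunks(stream=rows, n=4,
                                consumer=lambda batch: print(list(batch)))
assert total_items == 10

# metric-based batching: cap each chunk at 32 characters of row text
total_chars = reduce_in_chunks(stream=rows, n=32, initial=0,
                               consumer=lambda batch, acc: acc + sum(len(r) for r in batch),
                               metric=len)
assert total_chars == sum(len(r) for r in rows)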
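Reviewer note, not part of the patch: a sketch of how the new key fields relate to the Fixtures
builders; the key formats below are read off next_table and next_column in the diff above.

from amundsen_common.fixtures import next_table, next_user

table = next_table()
# table_key is built as '<database>://<cluster>.<schema>/<table>', column keys as '<table_key>/<name>'
assert table.key == f'{table.database}://{table.cluster}.{table.schema}/{table.name}'
assert all(c.key == f'{table.key}/{c.name}' for c in table.columns)

user = next_user()
assert user.email == f'{user.user_id}@example.com'
assert user.full_name == f'{user.first_name} {user.last_name}'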