diff --git a/CHANGES.md b/CHANGES.md
index 02f9913..fc46da2 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,14 +1,27 @@
+## 1.13.0
+
+`LazyMulticast` - A ContextManager-based interface for
+process-local multicasting of a producer.
+
+`funnel_latest_from_log_group` - CloudWatch Logs funnel that composes
+easily with the above.
+
+`funnel_sharded_stream` - A genericized sharded stream funnel that
+guarantees consumers receive every write made after it starts. The existing
+DynamoDB streams processor `process_latest_from_stream` is now built on it.
+
 ### 1.12.2
 
 Cover more exception types in retries for DynamoDB transaction
 utilities.
 
 ### 1.12.1
 
-Update `all_items_for_next_attempt` to be able to handle multiple tables with different key schemas
+Update `all_items_for_next_attempt` to be able to handle multiple
+tables with different key schemas
 
 ## 1.12.0
 
-An enhanced API supporting the use of
+`TypedTable` - An enhanced API supporting the use of
 `dynamodb.write_versioned.versioned_transact_write_items`,
 which is a great way to write business logic against DynamoDB.
 
@@ -82,7 +95,7 @@ avoid race conditions.
 
 - `map_tree` now supports postorder transformations via keyword
   argument.
 
-### 1.4.0
+## 1.4.0
 
 - Improved DynamoDB Item-related Exceptions for `GetItem`,
   `put_but_raise_if_exists`, and `versioned_diffed_update_item`.
diff --git a/dev-utils/cloudwatch-logs-to-local-file.py b/dev-utils/cloudwatch-logs-to-local-file.py
new file mode 100755
index 0000000..7c289ce
--- /dev/null
+++ b/dev-utils/cloudwatch-logs-to-local-file.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+"""This is mostly just a silly proof-of-concept"""
+import json
+import threading
+from functools import partial
+
+import boto3
+
+from xoto3.cloudwatch.logs import funnel_latest_from_log_group
+from xoto3.multicast import LazyMulticast
+
+CLOUDWATCH_LOGS = LazyMulticast(partial(funnel_latest_from_log_group, boto3.client("logs")))
+
+
+def write_log_events_to_file(log_group_name: str, filename: str):
+    with open(filename, "w") as outf:
+        with CLOUDWATCH_LOGS(log_group_name) as log_events:
+            for event in log_events:
+                outf.write(json.dumps(event) + "\n")
+
+
+def main():
+    while True:
+        log_group_name = input("Log Group Name: ")
+        if not log_group_name:
+            continue
+        try:
+            while True:
+                output_filename = input("output filename: ")
+                if not output_filename:
+                    continue
+                t = threading.Thread(
+                    target=write_log_events_to_file,
+                    args=(log_group_name, output_filename),
+                    daemon=True,
+                )
+                t.start()
+                break
+        except KeyboardInterrupt:
+            print("\n")
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except KeyboardInterrupt:
+        pass
diff --git a/dev-utils/watch-cloudwatch-log-stream.py b/dev-utils/watch-cloudwatch-log-stream.py
old mode 100644
new mode 100755
index 7f87b07..5d70da6
--- a/dev-utils/watch-cloudwatch-log-stream.py
+++ b/dev-utils/watch-cloudwatch-log-stream.py
@@ -1,33 +1,27 @@
-from datetime import datetime, timedelta
-from functools import partial
+#!/usr/bin/env python
+import argparse
+from datetime import datetime, timezone
 
 import boto3
 
-from xoto3.paginate import yield_pages_from_operation
+from xoto3.cloudwatch.logs import yield_filtered_log_events
 
-cw_client = boto3.client("logs")
-
-start_time = (datetime.utcnow() - timedelta(hours=20)).timestamp() * 1000
-end_time = datetime.utcnow().timestamp() * 1000
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("log_group_name")
+    parser.add_argument("--filter-pattern", "-f", default="")
+    args = parser.parse_args()
 
-query = dict(
-    logGroupName="xoi-ecs-logs-devl",
-    
logStreamNamePrefix="dataplateocr/dataplateocrContainer", - startTime=int(start_time), - endTime=int(end_time), -) + cw_client = boto3.client("logs") -nt = ("nextToken",) -CLOUDWATCH_FILTER_LOG_EVENTS = ( - nt, - nt, - ("limit",), - ("events",), -) + start_time = datetime.now(timezone.utc) -yield_cloudwatch_pages = partial(yield_pages_from_operation, *CLOUDWATCH_FILTER_LOG_EVENTS,) + for log_event in yield_filtered_log_events( + cw_client, args.log_group_name, start_time, args.filter_pattern + ): + print(log_event["message"]) -for page in yield_cloudwatch_pages(cw_client.filter_log_events, query,): - for event in page["events"]: - print(event["message"]) +if __name__ == "__main__": + main() diff --git a/dev-utils/watch_ddb_stream.py b/dev-utils/watch_ddb_stream.py index 4ba19e3..5e4b44a 100755 --- a/dev-utils/watch_ddb_stream.py +++ b/dev-utils/watch_ddb_stream.py @@ -8,16 +8,17 @@ import boto3 -from xoto3.dynamodb.streams.consume import process_latest_from_stream -from xoto3.dynamodb.streams.records import old_and_new_items_from_stream_event_record +from xoto3.dynamodb.streams.consume import ItemImages, make_dynamodb_stream_images_multicast from xoto3.dynamodb.utils.index import hash_key_name, range_key_name +DYNAMODB_STREAMS = make_dynamodb_stream_images_multicast() -def make_accept_stream_item_for_table(item_slicer: Callable[[dict], str]): - def accept_stream_item(record: dict): - old, new = old_and_new_items_from_stream_event_record(record) + +def make_accept_stream_images(item_slicer: Callable[[dict], str]): + def accept_stream_item(images: ItemImages): + old, new = images if not old: - print(f"New item: {item_slicer(new)}") + print(f"New item: {item_slicer(new)}") # type: ignore elif not new: print(f"Deleted item {item_slicer(old)}") else: @@ -53,8 +54,6 @@ def main(): DDB_RES = boto3.resource("dynamodb") - DDB_STREAMS_CLIENT = boto3.client("dynamodbstreams") - table = DDB_RES.Table(args.table_name) if args.attribute_names: @@ -74,14 +73,12 @@ def item_slicer(item: dict): item_slicer = make_key_slicer(table) try: - t, _kill = process_latest_from_stream( - DDB_STREAMS_CLIENT, - table.latest_stream_arn, - make_accept_stream_item_for_table(item_slicer), - ) - t.join() + accept_stream_images = make_accept_stream_images(item_slicer) + with DYNAMODB_STREAMS(args.table_name) as table_stream: + for images in table_stream: + accept_stream_images(images) except KeyboardInterrupt: - pass + pass # no noisy log - Ctrl-C for clean exit if __name__ == "__main__": diff --git a/tests/xoto3/backoff_test.py b/tests/xoto3/backoff_test.py new file mode 100644 index 0000000..94e3f0a --- /dev/null +++ b/tests/xoto3/backoff_test.py @@ -0,0 +1,39 @@ +import pytest +from botocore.exceptions import ClientError + +from xoto3.backoff import backoff + + +def make_named_ce(Code: str): + return ClientError({"Error": {"Code": Code}}, "test operation") + + +def test_backoff_some_client_errors(): + count = 0 + + @backoff + def fails_twice(): + nonlocal count + if count > 1: + return "done" + count += 1 + raise make_named_ce("ThrottlingException") + + assert "done" == fails_twice() + assert count == 2 + + +def test_dont_backoff_others(): + @backoff + def not_found(): + raise make_named_ce("NotFound") + + with pytest.raises(ClientError): + not_found() + + @backoff + def whoops(): + raise Exception("whoops") + + with pytest.raises(Exception): + whoops() diff --git a/tests/xoto3/dynamodb/streams/records_test.py b/tests/xoto3/dynamodb/streams/records_test.py new file mode 100644 index 0000000..be230b5 --- /dev/null +++ 
b/tests/xoto3/dynamodb/streams/records_test.py @@ -0,0 +1,53 @@ +import pytest + +from xoto3.dynamodb.streams.records import ( + ItemCreated, + ItemDeleted, + ItemModified, + current_nonempty_value, + matches_key, + old_and_new_dict_tuples_from_stream, +) +from xoto3.dynamodb.utils.serde import serialize_item + + +def _serialize_record(record: dict): + return {k: serialize_item(image) for k, image in record.items()} + + +def _serialize_records(*records): + return [dict(dynamodb=_serialize_record(rec["dynamodb"])) for rec in records] + + +def _fake_stream_event(): + return dict( + Records=_serialize_records( + dict(dynamodb=dict(NewImage=dict(id=1, val=2))), + dict(dynamodb=dict(NewImage=dict(id=1, val=3), OldImage=dict(id=1, val=2))), + dict(dynamodb=dict(NewImage=dict(id=2, bar=8), OldImage=dict(id=2, bar=-9))), + dict(dynamodb=dict(NewImage=dict(id=1, val=4), OldImage=dict(id=1, val=3))), + dict(dynamodb=dict(NewImage=dict(id=2, foo="steve"), OldImage=dict(id=2, bar=8))), + dict(dynamodb=dict(OldImage=dict(id=1, val=4))), + ) + ) + + +def test_current_nonempty_value(): + list_of_images = old_and_new_dict_tuples_from_stream(_fake_stream_event()) + + assert [dict(id=1, val=2), dict(id=1, val=3), dict(id=1, val=4)] == list( + current_nonempty_value(dict(id=1))(list_of_images) + ) + + +def test_matches_key_fails_with_empty_key(): + with pytest.raises(ValueError): + matches_key(dict()) + + +def test_matches_key_works_on_new_as_well_as_old(): + assert not matches_key(dict(id=3))(ItemCreated(None, dict(id=4))) + assert not matches_key(dict(id=3))(ItemDeleted(dict(id=4), None)) + assert not matches_key(dict(hash=1, range=3))( + ItemModified(dict(hash=1, range=4, foo=0), dict(hash=1, range=4, foo=1)) + ) diff --git a/tests/xoto3/paginate_test.py b/tests/xoto3/paginate_test.py new file mode 100644 index 0000000..54209e5 --- /dev/null +++ b/tests/xoto3/paginate_test.py @@ -0,0 +1,40 @@ +from xoto3.cloudwatch.logs.events import CLOUDWATCH_LOGS_FILTER_LOG_EVENTS +from xoto3.paginate import yield_pages_from_operation + + +class FakeApi: + def __init__(self, results): + self.calls = 0 + self.results = results + + def __call__(self, **kwargs): + self.calls += 1 + return self.results[self.calls - 1] + + +def test_pagination_with_nextToken_and_limit(): + + fake_cw = FakeApi( + [ + dict(nextToken="1", events=[1, 2, 3]), + dict(nextToken="2", events=[4, 5, 6]), + dict(nextToken="3", events=[7, 8, 9]), + ] + ) + + nt = None + + def le_cb(next_token): + nonlocal nt + nt = next_token + + collected_events = list() + for page in yield_pages_from_operation( + *CLOUDWATCH_LOGS_FILTER_LOG_EVENTS, fake_cw, dict(limit=6), last_evaluated_callback=le_cb + ): + for event in page["events"]: + collected_events.append(event) + + assert collected_events == list(range(1, 7)) + + assert nt == "2" diff --git a/tests/xoto3/utils/cm_test.py b/tests/xoto3/utils/cm_test.py new file mode 100644 index 0000000..7942263 --- /dev/null +++ b/tests/xoto3/utils/cm_test.py @@ -0,0 +1,20 @@ +from contextlib import contextmanager + +from xoto3.utils.cm import xf_cm + + +@contextmanager +def yield_3(): + print("generating a 3") + yield 3 + print("cleaning that 3 right on up") + + +def test_transform_context_manager(): + def add_one(x: int): + return x + 1 + + yield_4 = xf_cm(add_one)(yield_3()) + + with yield_4 as actually_four: + assert actually_four == 4 diff --git a/tests/xoto3/utils/multicast_test.py b/tests/xoto3/utils/multicast_test.py new file mode 100644 index 0000000..ca8d8ae --- /dev/null +++ b/tests/xoto3/utils/multicast_test.py @@ 
-0,0 +1,50 @@ +import threading +import time +import typing as ty +from collections import defaultdict + +from xoto3.utils.multicast import LazyMulticast + + +def test_lazy_multicast(): + class Recvr(ty.NamedTuple): + nums: ty.List[int] + + CONSUMER_COUNT = 10 + NUM_NUMS = 30 + sem = threading.Semaphore(0) + + def start_numbers_stream(num_nums: int, recv): + def stream_numbers(): + for i in range(CONSUMER_COUNT): + sem.acquire() + # wait for 10 consumers to start + for i in range(num_nums): + recv(i) + + t = threading.Thread(target=stream_numbers, daemon=True) + t.start() + return t.join + + mc = LazyMulticast(start_numbers_stream) # type: ignore + + consumer_results = defaultdict(list) + + def consume_numbers(): + sem.release() + thread_id = threading.get_ident() + with mc(NUM_NUMS) as nums_stream: + for i, num in enumerate(nums_stream): + consumer_results[thread_id].append(num) + if i == NUM_NUMS - 1: + break + + for i in range(CONSUMER_COUNT): + threading.Thread(target=consume_numbers, daemon=True).start() + + time.sleep(1) + + assert len(consumer_results) == CONSUMER_COUNT + + for results in consumer_results.values(): + assert list(range(NUM_NUMS)) == results diff --git a/tests/xoto3/utils/poll_test.py b/tests/xoto3/utils/poll_test.py new file mode 100644 index 0000000..1784433 --- /dev/null +++ b/tests/xoto3/utils/poll_test.py @@ -0,0 +1,26 @@ +# type: ignore +import queue + +import pytest + +from xoto3.utils.poll import QueuePollIterable, TimedOut, expiring_poll_iter + + +def test_queue_poll_iterable(): + qpi = QueuePollIterable(queue.Queue()) + qpi.iter_timeout = 1.0 + + with pytest.raises(StopIteration): + next(iter(qpi)) + + with pytest.raises(TimedOut): + qpi(0.5) + + +def test_expiring_poll_iter(): + qpi = QueuePollIterable(queue.Queue()) + + epi = expiring_poll_iter(0.3)(qpi) + + with pytest.raises(StopIteration): + next(epi) diff --git a/xoto3/__about__.py b/xoto3/__about__.py index f5c7156..c61b439 100644 --- a/xoto3/__about__.py +++ b/xoto3/__about__.py @@ -1,4 +1,4 @@ """xoto3""" -__version__ = "1.12.2" +__version__ = "1.13.0" __author__ = "Peter Gaultney" __author_email__ = "pgaultney@xoi.io" diff --git a/xoto3/backoff.py b/xoto3/backoff.py index 1a59d15..47b2da3 100644 --- a/xoto3/backoff.py +++ b/xoto3/backoff.py @@ -1,10 +1,14 @@ -import time +"""Infinite exponential backoff for AWS API throttling. + +See retry.py for more general purpose retry strategy utilities. + +""" from logging import getLogger as get_logger import botocore.exceptions -from xoto3.errors import client_error_name - +from .errors import client_error_name +from .utils.retry import expo, retry_while, sleep_join logger = get_logger(__name__) @@ -12,24 +16,11 @@ RETRY_EXCEPTIONS = ("ProvisionedThroughputExceededException", "ThrottlingException") -def backoff(func): - """Will retry a boto3 operation closure until it succeeds, as long - as the exception was throughput-related. 
- """ - - def backoff_wrapper(*args, **kwargs): - retries = 0 - pause_time = 0 +def _is_boto3_retryable(e: Exception) -> bool: + if not isinstance(e, botocore.exceptions.ClientError): + return False + return client_error_name(e) in RETRY_EXCEPTIONS - while True: - try: - return func(*args, **kwargs) - except botocore.exceptions.ClientError as ce: - if client_error_name(ce) not in RETRY_EXCEPTIONS: - raise - pause_time = 2 ** retries - logger.info("Back-off set to %d seconds", pause_time) - time.sleep(pause_time) - retries += 1 - return backoff_wrapper +backoff = retry_while(_is_boto3_retryable for _ in sleep_join(expo())) +"""Infinite exponential backoff for the specified errors""" diff --git a/xoto3/cloudwatch/logs/__init__.py b/xoto3/cloudwatch/logs/__init__.py new file mode 100644 index 0000000..1928d75 --- /dev/null +++ b/xoto3/cloudwatch/logs/__init__.py @@ -0,0 +1,2 @@ +from .events import yield_filtered_log_events # noqa +from .funnel import funnel_latest_from_log_group # noqa diff --git a/xoto3/cloudwatch/logs/events.py b/xoto3/cloudwatch/logs/events.py new file mode 100644 index 0000000..68913fb --- /dev/null +++ b/xoto3/cloudwatch/logs/events.py @@ -0,0 +1,62 @@ +import time +import typing as ty +from datetime import datetime + +from typing_extensions import TypedDict + +from xoto3.paginate import PagePathsTemplate, yield_pages_from_operation + + +class LogEvent(TypedDict): + timestamp: int + message: str + ingestionTime: int + + +class FilteredLogEvent(LogEvent, total=False): + logStreamName: str + eventId: str + + +CLOUDWATCH_LOGS_FILTER_LOG_EVENTS = PagePathsTemplate( + ("nextToken",), ("nextToken",), ("limit",), ("events",), +) + + +def yield_filtered_log_events( + client, + log_group_name: str, + start_time: datetime, + filter_pattern: str = "", + watch_interval: float = 1.0, + # set to a negative number or 0 to stop yielding events when + # the end of the stream is reached. +) -> ty.Iterator[FilteredLogEvent]: + """This will iterate infinitely unless watch_interval is set to 0 or negative.""" + req = dict(logGroupName=log_group_name, startTime=int(start_time.timestamp() * 1000)) + if filter_pattern: + req["filterPattern"] = filter_pattern + + def intercept_nextToken(next_token): + """When filter_log_events returns an empty nextToken, that means it + knows of nothing further in the log group. But it turns out + you can hang on to the previous nextToken and resume calling + later on, and you should pick up where you left off with new + log events. 
+ """ + + if next_token: + req["nextToken"] = next_token + + while True: + pages = yield_pages_from_operation( + *CLOUDWATCH_LOGS_FILTER_LOG_EVENTS, + client.filter_log_events, + req, + last_evaluated_callback=intercept_nextToken, + ) + for page in pages: + yield from page["events"] + if watch_interval <= 0: + break + time.sleep(watch_interval) diff --git a/xoto3/cloudwatch/logs/funnel.py b/xoto3/cloudwatch/logs/funnel.py new file mode 100644 index 0000000..a6aef7a --- /dev/null +++ b/xoto3/cloudwatch/logs/funnel.py @@ -0,0 +1,30 @@ +import threading +import typing as ty +from datetime import datetime, timezone + +from xoto3.utils.multicast import Cleanup + +from .events import LogEvent, yield_filtered_log_events + +LogEventFunnel = ty.Callable[[LogEvent], None] + + +def funnel_latest_from_log_group( + cloudwatch_logs_client, log_group_name: str, log_event_funnel: LogEventFunnel, +) -> Cleanup: + start = datetime.now(timezone.utc) + bottle = dict(poisoned=False) + + def put_logs_into_funnel(): + for log_event in yield_filtered_log_events(cloudwatch_logs_client, log_group_name, start): + if bottle["poisoned"]: + break + log_event_funnel(log_event) + + thread = threading.Thread(target=put_logs_into_funnel, daemon=True) + thread.start() + + def poison(): + bottle["poisoned"] = True + + return poison diff --git a/xoto3/dynamodb/streams/consume.py b/xoto3/dynamodb/streams/consume.py index a1f278f..374dcb3 100644 --- a/xoto3/dynamodb/streams/consume.py +++ b/xoto3/dynamodb/streams/consume.py @@ -1,102 +1,64 @@ import typing as ty -import time -from uuid import uuid4 -from threading import Thread -from logging import getLogger +import boto3 + +from xoto3.utils.multicast import LazyMulticast, OnNext +from xoto3.utils.stream import ShardedStreamFunnelController, funnel_sharded_stream + +from .records import ItemImages, old_and_new_items_from_stream_event_record from .shards import ( - Shard, - key_shard, refresh_live_shards, shard_iterator_from_shard, yield_records_from_shard_iterator, ) -logger = getLogger(__name__) +DynamoDbStreamEventConsumer = ty.Callable[[dict], None] -def process_latest_from_stream(client, stream_arn: str, stream_consumer, sleep_s: int = 10): +def process_latest_from_stream( + client, stream_arn: str, stream_consumer: DynamoDbStreamEventConsumer, sleep_s: float = 10.0 +) -> ShardedStreamFunnelController: """This spawns a thread which spawns other threads that each handle a - DynamoDB Stream Shard. Your consumer will get everything from every shard. - - This is therefore obviously _not_ suitable for cases where there - might be lots and lots of data. This is a utility for other cases. - - It returns the thread itself (in case you want to block until the - stream runs out, which incidentally will probably never happen) - and also a thunk that may be called to poison all the existing - shard consumers so that you don't get any more data and they - eventually close out their threads as well. - - Note however that since Python threads are not interruptible, - poisoning them will only take effect once they've received their - next item from their shard. If your table is not in active use - that may never happen. - - The good news is that all of these threads are started as daemon - threads, so if your process exits these threads will exit - immediately as well. + DynamoDB Stream Shard. Your consumer/funnel will get everything from every shard. + See the docstring on funnel_sharded_stream for further details. 
+ """ + return funnel_sharded_stream( + lambda: refresh_live_shards(client, stream_arn), + lambda shard: shard_iterator_from_shard(client, "LATEST", shard), + lambda shard: shard_iterator_from_shard(client, "TRIM_HORIZON", shard), + lambda shard_it: yield_records_from_shard_iterator(client, shard_it), + stream_consumer, + shard_refresh_interval=sleep_s, + ) + + +def make_dynamodb_stream_images_multicast( + shard_refresh_interval: float = 2.0, +) -> LazyMulticast[str, ItemImages]: + """A slightly cleaner interface to process_latest_from_stream, + particularly if you want to be able to share the output across + multiple consumers. """ - processing_id = uuid4().hex # for debugging/sanity - live_processors_by_key: ty.Dict[str, Thread] = dict() - emptied_shards: ty.Set[str] = set() - - def shard_emptied(shard: Shard): - logger.debug(f"{processing_id} shard emptied {shard}") - shard_key = key_shard(shard) - if shard_key in live_processors_by_key: - live_processors_by_key.pop(shard_key) - emptied_shards.add(shard_key) - - sentinel_container = dict(kill_me=False) - - def run(): - iterator_type = "LATEST" - new_shards_by_key = refresh_live_shards(client, stream_arn) - - while not sentinel_container["kill_me"]: - for key, shard in new_shards_by_key.items(): - dbg = f"{processing_id} - {key}" - shard_iterator = shard_iterator_from_shard(client, iterator_type, shard) - - def run_shard_and_cleanup(): - try: - logger.debug( - f'{dbg} starting shard processor with shard {shard_iterator["ShardIterator"]}' - ) - for item in yield_records_from_shard_iterator(client, shard_iterator): - if sentinel_container["kill_me"]: - break - stream_consumer(item) - logger.debug(f"{dbg} exiting shard iterator processor") - except Exception as e: - logger.exception(e) - shard_emptied(shard) - - logger.debug(f"{dbg} spawning new shard processor for shard {key}") - live_processors_by_key[key] = Thread(target=run_shard_and_cleanup) - live_processors_by_key[key].start() - iterator_type = "TRIM_HORIZON" - # for all subsequent shards, start at their beginning - - time.sleep(sleep_s) - - new_shards_by_key = refresh_live_shards(client, stream_arn) - for shard_key in set(live_processors_by_key.keys()) | emptied_shards: - new_shards_by_key.pop(shard_key, None) - - if not live_processors_by_key and not new_shards_by_key: - break # out of the loop - - # we use a thread so that we can actually return to the caller a way to shut all this down - # Python threads are not interruptible, so this is a little uglier than one might wish - thread = Thread(target=run, daemon=True) - # it's a daemon thread so that this thread will not keep your program alive by itself - # This way, a Ctrl-C or other exit will do the trick cleanly. 
- thread.start() - - def kill(): - sentinel_container["kill_me"] = True - return thread, kill + def start_dynamodb_stream( + table_name: str, stream_event_callback: OnNext[ItemImages] + ) -> ty.Callable[[], None]: + session = boto3.session.Session() + table = session.resource("dynamodb").Table(table_name) + thread, kill = process_latest_from_stream( # pylint: disable=unpacking-non-sequence + session.client("dynamodbstreams"), + table.latest_stream_arn, # type: ignore + lambda record_dict: stream_event_callback( + old_and_new_items_from_stream_event_record(record_dict) + ), + sleep_s=shard_refresh_interval, + ) + + def cleanup_ddb_stream(): + kill() + thread.join() + + return cleanup_ddb_stream + + return LazyMulticast(start_dynamodb_stream) diff --git a/xoto3/dynamodb/streams/records.py b/xoto3/dynamodb/streams/records.py index 8dca815..7d3fc5d 100644 --- a/xoto3/dynamodb/streams/records.py +++ b/xoto3/dynamodb/streams/records.py @@ -1,16 +1,52 @@ -"""Dynamo Streams record processing""" +"""DynamoDB Streams record processing types and utilities""" import typing as ty - from logging import getLogger -from xoto3.dynamodb.utils.serde import deserialize_item -from xoto3.dynamodb.types import Item +from typing_extensions import Literal +from xoto3.dynamodb.types import Item, ItemKey +from xoto3.dynamodb.utils.serde import deserialize_item logger = getLogger(__name__) -def old_and_new_items_from_stream_event_record(event_record: dict,) -> ty.Tuple[Item, Item]: +class ItemCreated(ty.NamedTuple): + old: Literal[None] + new: Item + + +class ItemModified(ty.NamedTuple): + old: Item + new: Item + + +class ItemDeleted(ty.NamedTuple): + old: Item + new: Literal[None] + + +ItemImages = ty.Union[ItemCreated, ItemModified, ItemDeleted] +ExistingItemImages = ty.Union[ItemCreated, ItemModified] # a common alias + + +def item_images(old: ty.Optional[Item], new: ty.Optional[Item]) -> ItemImages: + if not old: + assert new, "If old is not present then this should be a newly created item" + return ItemCreated(None, new) + if not new: + assert old, "If new is not present then this should be a newly deleted item" + return ItemDeleted(old, None) + return ItemModified(old, new) + + +def old_and_new_items_from_stream_record_body(stream_record_body: dict) -> ItemImages: + """If you're using the `records` wrapper this will get you what you need.""" + new = deserialize_item(stream_record_body.get("NewImage", {})) + old = deserialize_item(stream_record_body.get("OldImage", {})) + return item_images(old, new) + + +def old_and_new_items_from_stream_event_record(event_record: dict) -> ItemImages: """The event['Records'] list of dicts from a Dynamo stream as delivered to a Lambda can have each of its records processed individually by this function to deliver the two 'images' from the record. @@ -25,15 +61,47 @@ def old_and_new_items_from_stream_event_record(event_record: dict,) -> ty.Tuple[ return old_and_new_items_from_stream_record_body(event_record["dynamodb"]) -def old_and_new_dict_tuples_from_stream(event: dict) -> ty.List[ty.Tuple[Item, Item]]: - """Utility wrapper for old_and_new_items_from_stream_event_record""" - tuples = [old_and_new_items_from_stream_event_record(record) for record in event["Records"]] - logger.debug(f"Extracted {len(tuples)} stream records from the event.") - return tuples +def old_and_new_dict_tuples_from_stream(event: dict) -> ty.List[ItemImages]: + """Logging wrapper for a whole stream event. 
You probably don't want to use this.""" + images = [old_and_new_items_from_stream_event_record(record) for record in event["Records"]] + logger.debug(f"Extracted {len(images)} stream records from the stream event.") + return images -def old_and_new_items_from_stream_record_body(stream_record_body: dict,) -> ty.Tuple[Item, Item]: - """If you're using the `records` wrapper this will get you what you need.""" - new = deserialize_item(stream_record_body.get("NewImage", {})) - old = deserialize_item(stream_record_body.get("OldImage", {})) - return old, new +def matches_key(item_key: ItemKey) -> ty.Callable[[ItemImages], bool]: + if not item_key: + raise ValueError("Empty item key") + + def _matches_key(images: ItemImages) -> bool: + """a filter function""" + old, new = images + for k, kv in item_key.items(): + if old and not old.get(k) == kv: + return False + if new and not new.get(k) == kv: + return False + return True + + return _matches_key + + +def filter_existing(images_iter: ty.Iterable[ItemImages]) -> ty.Iterator[ExistingItemImages]: + for item_images in images_iter: + if item_images.new: + yield item_images + + +def current_nonempty_value( + item_key: ItemKey, +) -> ty.Callable[[ty.Iterable[ItemImages]], ty.Iterator[Item]]: + """We're only interested in a stream of the current values for a + particular item, and we're not interested in deletions. + """ + item_matcher = matches_key(item_key) + + def item_only_if_it_exists(images_stream: ty.Iterable[ItemImages]) -> ty.Iterator[Item]: + for existing_item_images in filter_existing(images_stream): + if item_matcher(existing_item_images): + yield existing_item_images.new + + return item_only_if_it_exists diff --git a/xoto3/dynamodb/streams/shards.py b/xoto3/dynamodb/streams/shards.py index 490ae47..42a1c24 100644 --- a/xoto3/dynamodb/streams/shards.py +++ b/xoto3/dynamodb/streams/shards.py @@ -59,7 +59,7 @@ def get_stream_arn_for_table(table_name: str, streams) -> str: return list(filter(lambda s: s["TableName"] == table_name, streams))[0]["StreamArn"] -def yield_shards(client, StreamArn: str) -> ty.Iterable[Shard]: +def yield_shards(client, StreamArn: str) -> ty.Iterator[Shard]: page_yielder = yield_pages_from_operation( *DYNAMODB_STREAMS_DESCRIBE_STREAM, client.describe_stream, dict(StreamArn=StreamArn) ) @@ -79,7 +79,7 @@ def shard_iterator_from_shard( ) -def yield_records_from_shard_iterator(client, shard_iterator: ShardIterator) -> ty.Iterable[dict]: +def yield_records_from_shard_iterator(client, shard_iterator: ShardIterator) -> ty.Iterator[dict]: yielder = yield_pages_from_operation( *DYNAMODB_STREAMS_GET_RECORDS, client.get_records, @@ -93,7 +93,7 @@ def is_shard_live(shard: Shard) -> bool: return "EndingSequenceNumber" not in shard["SequenceNumberRange"] -def only_live_shards(shards: ty.Iterable[Shard]) -> ty.Iterable[Shard]: +def only_live_shards(shards: ty.Iterable[Shard]) -> ty.Iterator[Shard]: for shard in shards: if is_shard_live(shard): yield shard @@ -107,7 +107,7 @@ def key_shards(shards: ty.List[Shard]) -> ty.Dict[str, Shard]: return {key_shard(shard): shard for shard in shards} -def live_shard_chains(shards: ty.List[Shard]) -> ty.Iterable[ty.List[Shard]]: +def live_shard_chains(shards: ty.List[Shard]) -> ty.Iterator[ty.List[Shard]]: shards_by_key = key_shards(shards) live_shards = only_live_shards(shards) for live_shard in live_shards: diff --git a/xoto3/dynamodb/write_versioned/retry.py b/xoto3/dynamodb/write_versioned/retry.py index 42e2947..ff40a38 100644 --- a/xoto3/dynamodb/write_versioned/retry.py +++ 
b/xoto3/dynamodb/write_versioned/retry.py @@ -36,10 +36,6 @@ def timed_retry( while attempt == 0 or time.monotonic() <= expiring_at: attempt += 1 yield # make an attempt - msg = ( - "Attempt %d to perform transaction was beaten " - + "by a different attempt. Sleeping for %s seconds." - ) sleep = _choose_sleep_len_to_average_N_attempts_in_the_total_interval( max_attempts_before_expiration, random_sleep_length, @@ -50,6 +46,7 @@ def timed_retry( # we've exceeded our maximum attempts and must exit break logger.warning( - msg, attempt, f"{sleep:.3f}", + f"Attempt {attempt} to perform transaction was beaten " + f"by a different attempt. Sleeping for {sleep:.3f} seconds.", ) time.sleep(sleep) diff --git a/xoto3/paginate.py b/xoto3/paginate.py index 46a6651..6a52747 100644 --- a/xoto3/paginate.py +++ b/xoto3/paginate.py @@ -23,11 +23,20 @@ def set_at_path(path, d, val): d[path[-1]] = val +class PagePathsTemplate(ty.NamedTuple): + """Not necessary - just a template for defining paginators""" + + exclusive_start_path: KeyPath + last_evaluated_path: KeyPath + limit_path: KeyPath + items_path: KeyPath + + def yield_pages_from_operation( exclusive_start_path: KeyPath, last_evaluated_path: KeyPath, - limit_path: ty.Tuple[str, ...], - items_path: ty.Tuple[str, ...], + limit_path: KeyPath, + items_path: KeyPath, # whether or not limiting _happens_ is controlled by whether you set a limit in your request dict # but if you provide limit_path you must provide items_path and vice-versa, # or we won't be able figure out how to create the new limit for each paged request. @@ -88,6 +97,11 @@ def yield_pages_from_operation( specific operation), so that you can then invoke the same operation paginator repeatedly with different requests. + If you're interested in the actual value of the LastEvaluated + token (or its equivalent), you may pass in a callback which we + will call before every page that we yield. Thus, your code will + have the opportunity to examine the current token before + re-entering this generator. """ assert all((limit_path, items_path)) or not any((limit_path, items_path)) request = deepcopy(request) @@ -116,15 +130,17 @@ def yield_pages_from_operation( if limit: set_limit(request, limit) page_response = operation(**request) + last_evaluated = get_le(page_response) + if last_evaluated_callback: + # we call your callback for every page, not just the last one. + last_evaluated_callback(last_evaluated) yield page_response # we yield the entire response - ExclusiveStart = get_le(page_response) or None + ExclusiveStart = last_evaluated or None if starting_limit: # a limit was requested limit = limit - len(get_items(page_response)) if limit <= 0: - # we're done; before we leave, provide last evaluated if requested - if last_evaluated_callback: - last_evaluated_callback(ExclusiveStart) + # we're done ExclusiveStart = None diff --git a/xoto3/utils/cm.py b/xoto3/utils/cm.py new file mode 100644 index 0000000..b843d6c --- /dev/null +++ b/xoto3/utils/cm.py @@ -0,0 +1,24 @@ +import contextlib +from typing import Callable, ContextManager, Iterator, TypeVar + +X = TypeVar("X") +Y = TypeVar("Y") + + +def xf_cm(xf: Callable[[X], Y]) -> Callable[[ContextManager[X]], ContextManager[Y]]: + """Transform a ContextManager that returns X into a ContextManager that returns Y. + + By 'returns' we mean the value returned by __enter__. + + Useful if you commonly want to use a particular type of context manager in a different way. 
+ """ + + def _(cm: ContextManager[X]) -> ContextManager[Y]: + @contextlib.contextmanager + def xfing_context() -> Iterator[Y]: + with cm as entered_cm: + yield xf(entered_cm) + + return xfing_context() + + return _ diff --git a/xoto3/utils/multicast.py b/xoto3/utils/multicast.py new file mode 100644 index 0000000..f12a72f --- /dev/null +++ b/xoto3/utils/multicast.py @@ -0,0 +1,91 @@ +import contextlib +import queue +import threading +import typing as ty +from functools import partial + +from .poll import QueuePollIterable + + +class _Producer(ty.NamedTuple): + queues: ty.Dict[int, queue.Queue] + cleanup: ty.Callable[[], ty.Any] + + +E = ty.TypeVar("E") + +H = ty.TypeVar("H", bound=ty.Hashable) + +OnNext = ty.Callable[[E], ty.Any] +Cleanup = ty.Callable[[], ty.Any] + + +class LazyMulticast(ty.Generic[H, E]): + """Allows concurrent process-local subscribers to an expensive + producer. Each subscriber will receive _every_ event produced after + their subscription begins. + + Almost by definition, your producer should exist in a thread. + `start_producer` should return only after the producer is fully + set up, but should not block beyond what is necessary to get the + producer running. + + Implemented as a factory for ContextManagers, to allow clients to + easily unsubscribe. + + Expects its subscribers to operate within threads, so it is + threadsafe. + + Because this anticipates the use of threads, it also allows + subscribers to use an interface that may optionally be + non-blocking. See QueuePollIterable for details. + """ + + def __init__( + self, start_producer: ty.Callable[[H, ty.Callable[[E], ty.Any]], Cleanup], + ): + self.lock = threading.Lock() + self.start_producer = start_producer + self.producers: ty.Dict[H, _Producer] = dict() + + def _recv_event_from_producer(self, producer_key: H, producer_event: E): + ss = self.producers.get(producer_key) + if not ss: + # no current consumers + return + for q in list(ss.queues.values()): + q.put(producer_event) + + def __call__(self, producer_key: H) -> ty.ContextManager[QueuePollIterable[E]]: + """Constructs a context manager that will provide access to the + underlying multicasted producer. + + This context manager is inactive until entered using `with` - + i.e no producer exists or is subscribed. + """ + + @contextlib.contextmanager + def queue_poll_context() -> ty.Iterator[QueuePollIterable[E]]: + with self.lock: + if producer_key not in self.producers: + # create a single shared producer + cleanup = self.start_producer( + producer_key, partial(self._recv_event_from_producer, producer_key) + ) + self.producers[producer_key] = _Producer(dict(), cleanup) + + ss = self.producers[producer_key] + q: queue.Queue = queue.Queue() + ss.queues[id(q)] = q + + yield QueuePollIterable(q) + + with self.lock: + # clean up the consumer + ss.queues.pop(id(q)) + if not ss.queues: + # remove the producer consumer if no one is listening + ss.cleanup() + self.producers.pop(producer_key) + + return queue_poll_context() diff --git a/xoto3/utils/poll.py b/xoto3/utils/poll.py new file mode 100644 index 0000000..596b677 --- /dev/null +++ b/xoto3/utils/poll.py @@ -0,0 +1,64 @@ +"""A thin but not terribly composable interface combining the worlds +of "things you might want to iterate over" and "things that you might +not want to block your thread on indefinitely. 
+""" +import queue +import time +from typing import Callable, Iterable, Iterator, Optional, TypeVar + +from typing_extensions import Protocol + +X = TypeVar("X") +Y_co = TypeVar("Y_co", covariant=True) + + +class TimedOut(Exception): + pass + + +class Poll(Protocol[Y_co]): + def __call__(self, __timeout: Optional[float] = None) -> Y_co: + """raises TimedOut exception after timeout""" + ... # pragma: nocover + + +def expiring_poll_iter(seconds_from_start: float) -> Callable[[Poll[X]], Iterator[X]]: + """Stops iterating (does not raise) when time has expired.""" + + def _expiring_iter(poll: Poll[X]) -> Iterator[X]: + time_left = seconds_from_start + end = time.monotonic() + seconds_from_start + while time_left > 0: + try: + yield poll(time_left) + except TimedOut: + pass + time_left = end - time.monotonic() + + return _expiring_iter + + +class QueuePollIterable(Iterable[X], Poll[X]): + """A convenience implementation that provides infinite queue iterators + to simple clients, and a simplified polling interface to clients + that have a need to control timeout behavior. + + """ + + def __init__(self, q: queue.Queue): + self.q = q + self.iter_timeout = None + + def __iter__(self) -> Iterator[X]: + while True: + try: + yield self(self.iter_timeout) + except TimedOut: + return + + def __call__(self, timeout: Optional[float] = None) -> X: + """Raises TimedOut after timeout""" + try: + return self.q.get(block=True, timeout=timeout) + except queue.Empty: + raise TimedOut() diff --git a/xoto3/utils/retry.py b/xoto3/utils/retry.py new file mode 100644 index 0000000..c52f50d --- /dev/null +++ b/xoto3/utils/retry.py @@ -0,0 +1,44 @@ +import time +import typing as ty +from functools import wraps + +F = ty.TypeVar("F", bound=ty.Callable) + + +def expo(length: int = -1, y: float = 1.0,) -> ty.Iterator[float]: + """Ends iteration after 'length'. + + If you want infinite exponential values, pass a negative number for 'length'. + """ + count = 0 + while length < 0 or count < length: + yield 2 ** count * y + count += 1 + + +def sleep_join( + seconds_iter: ty.Iterable[float], sleep: ty.Callable[[float], ty.Any] = time.sleep +) -> ty.Iterator: + """A common base strategy for separating retries by sleeps.""" + yield + for secs in seconds_iter: + sleep(secs) + yield + + +def retry_while(strategy: ty.Iterable[ty.Callable[[Exception], bool]]) -> ty.Callable[[F], F]: + """Uses your retry strategy every time an exception is raised.""" + + def _retry_decorator(func: F) -> F: + @wraps(func) + def retry_wrapper(*args, **kwargs): + for is_retryable in strategy: + try: + return func(*args, **kwargs) + except Exception as e: + if not is_retryable(e): + raise + + return ty.cast(F, retry_wrapper) + + return _retry_decorator diff --git a/xoto3/utils/stream.py b/xoto3/utils/stream.py new file mode 100644 index 0000000..f0638bf --- /dev/null +++ b/xoto3/utils/stream.py @@ -0,0 +1,147 @@ +import threading +import time +import typing as ty +from logging import getLogger +from uuid import uuid4 + +logger = getLogger(__name__) + +E = ty.TypeVar("E") +StreamEventFunnel = ty.Callable[[E], ty.Any] + +Shard = ty.TypeVar("Shard") +ShardIterator = ty.TypeVar("ShardIterator") + + +class ShardedStreamFunnelController(ty.NamedTuple): + thread: threading.Thread + # the thread which will poll on a schedule for new shards and + # activate a consumer for each new shard found. 
poison: ty.Callable[[], None]
+    # call poison to stop polling for new shards and to stop each shard
+    # thread when it receives its next event
+
+
+def funnel_sharded_stream(
+    refresh_live_shards: ty.Callable[[], ty.Dict[str, Shard]],
+    startup_shard_iterator: ty.Callable[[Shard], ShardIterator],
+    future_shard_iterator: ty.Callable[[Shard], ShardIterator],
+    iterate_shard: ty.Callable[[ShardIterator], ty.Iterable[E]],
+    # a shard should iterate until it is fully consumed.
+    stream_event_funnel: StreamEventFunnel,
+    shard_refresh_interval: float = 5.0,
+) -> ShardedStreamFunnelController:
+
+    """A single consumer stream processor for sharded streams.
+
+    Spawns a thread polling for live shards, which spawns other
+    threads, each of which will consume a single stream shard. Your
+    event funnel will get everything from every shard. Because each
+    shard will be consumed by a separate thread, your event funnel
+    MUST be threadsafe.
+
+    This is _not_ suitable for cases where there might be lots and
+    lots of data on lots and lots of shards. This is a utility for
+    other cases.
+
+    Returns the spawned polling thread, which you can use to block
+    until poisoning has happened, along with a thunk that may be
+    called to poison all the existing shard consumers so that you
+    stop receiving data and their threads eventually close out.
+    Note, however, that since Python threads are not interruptible,
+    poisoning them will only take effect once they've received their
+    next item from their shard. If your table is not in active use
+    that may never happen.
+
+    The good news is that all of these threads are started as daemon
+    threads, so if your process exits these threads will exit
+    immediately as well.
+
+    """
+    processing_id = uuid4().hex  # for debugging/sanity
+    shard_processing_threads: ty.Dict[str, threading.Thread] = dict()
+    emptied_shards: ty.Set[str] = set()
+
+    def shard_emptied(shard_id: str):
+        if shard_id in shard_processing_threads:
+            logger.debug(f"{processing_id} shard emptied {shard_id}")
+            shard_processing_threads.pop(shard_id, None)
+        emptied_shards.add(shard_id)
+
+    cloth = dict(poisoned=False)
+    # ...many threads
+
+    initial_fetch_completed = threading.Semaphore(0)
+    # we want to fetch the live shards and their iterators before this
+    # function returns, so that client code that might generate new shards
+    # (by, for instance, writing a new item) cannot run before we're
+    # 'ready' to identify those shards as new and start processing
+    # them from the beginning.
+    #
+    # In other words, this is how we make sure that anything done by a
+    # user of this method after calling this method will show up in
+    # their consumer.
+
+    def find_and_consume_all_shards():
+        logger.info(f"{processing_id} Beginning to find and consume all shards")
+        shard_iterator_fetcher = startup_shard_iterator
+        shards_not_yet_started = refresh_live_shards()
+
+        while not cloth["poisoned"]:
+            for shard_id, shard in shards_not_yet_started.items():
+                dbg = f"{processing_id} - {shard_id}"
+                shard_iterator = shard_iterator_fetcher(shard)
+
+                def consume_shard():
+                    try:
+                        logger.info(f"{dbg} starting shard processor with shard {shard_iterator}")
+                        for event in iterate_shard(shard_iterator):
+                            if cloth["poisoned"]:
+                                break
+                            stream_event_funnel(event)
+                        logger.info(f"{dbg} exiting shard iterator processor")
+                    except Exception as e:
+                        logger.exception(e)
+                    shard_emptied(shard_id)
+
+                logger.info(f"{dbg} spawning new shard processor for shard {shard_id}")
+                shard_processing_threads[shard_id] = threading.Thread(
+                    target=consume_shard, daemon=True
+                )
+                shard_processing_threads[shard_id].start()
+
+            initial_fetch_completed.release()
+            # the fact that this gets released more than once is not a problem.
+            # we only use it to guard the first trip through the for loop above.
+
+            # we've released the main thread. any future shards that
+            # we discover may use a different strategy for choosing
+            # what part of the shard to start at.
+            shard_iterator_fetcher = future_shard_iterator
+
+            time.sleep(shard_refresh_interval)
+
+            shards_not_yet_started = refresh_live_shards()
+            logger.info(f"{processing_id} refreshing live shards")
+            for shard_id in set(shard_processing_threads) | emptied_shards:
+                shards_not_yet_started.pop(shard_id, None)  # type: ignore
+
+        logger.info(f"{processing_id} Ending search for shards")
+
+    # we use a thread so that we can actually return to the caller a way to shut all this down
+    # Python threads are not interruptible, so this is a little uglier than one might wish
+    thread = threading.Thread(target=find_and_consume_all_shards, daemon=True)
+    # it's a daemon thread so that this thread will not keep your program alive by itself
+    # This way, a Ctrl-C or other exit will do the trick cleanly.
+    thread.start()
+
+    initial_fetch_completed.acquire()
+    # once the initial fetch has completed, we can let the caller know
+    # that it's safe to proceed, in case they want to write anything to
+    # the stream's source.
+
+    def poison():
+        logger.info(f"{processing_id} poisoning shard runners")
+        cloth["poisoned"] = True
+
+    return ShardedStreamFunnelController(thread, poison)
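
## Usage sketches

The new `xoto3.utils.retry` primitives compose by pairing a retry
predicate with a sleep schedule: `sleep_join(expo())` yields one attempt
slot immediately and another after each sleep, and `retry_while` makes
one attempt per slot. A minimal sketch of a bounded variant; the
`is_retryable` predicate and `flaky_fetch` operation are hypothetical,
everything else is the API from the patch above:

```python
from xoto3.utils.retry import expo, retry_while, sleep_join


def is_retryable(e: Exception) -> bool:
    # hypothetical predicate: only connection hiccups are worth retrying
    return isinstance(e, ConnectionError)


def flaky_fetch() -> str:
    # hypothetical operation that sometimes raises ConnectionError
    return "ok"


# expo(length=4) yields sleeps of 1, 2, 4, 8 seconds; sleep_join turns
# that into 5 attempt slots, so this permits up to 5 attempts total.
result = retry_while(is_retryable for _ in sleep_join(expo(length=4)))(flaky_fetch)()
```

Because the strategy above is a single generator, it is consumed by use;
building the retrying callable at the point of use gives each invocation
a fresh schedule. Note also that if every slot is exhausted by retryable
failures, the wrapped call returns `None` rather than raising.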
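
`yield_pages_from_operation` now reports the last-evaluated token before
every page it yields, which is what lets the CloudWatch funnel resume
where a previous pass left off. A sketch of capturing that token; the
log group name is hypothetical:

```python
import boto3

from xoto3.cloudwatch.logs.events import CLOUDWATCH_LOGS_FILTER_LOG_EVENTS
from xoto3.paginate import yield_pages_from_operation

cw_client = boto3.client("logs")
tokens = []  # receives the nextToken observed before each yielded page

pages = yield_pages_from_operation(
    *CLOUDWATCH_LOGS_FILTER_LOG_EVENTS,  # the four key paths for filter_log_events
    cw_client.filter_log_events,
    dict(logGroupName="my-log-group", limit=100),
    last_evaluated_callback=tokens.append,
)
for page in pages:
    for event in page["events"]:
        print(event["message"])

# tokens[-1], if not None, can be stored and sent as nextToken in a
# later request to pick up roughly where this run left off.
```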
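
`LazyMulticast` starts one producer per key on first subscription, fans
each event out to every subscriber's queue, and calls the producer's
cleanup thunk when the last subscriber leaves. A sketch with a
hypothetical ticker producer:

```python
import threading
import time

from xoto3.utils.multicast import LazyMulticast


def start_ticker(name: str, on_next):
    # producer contract: receive (key, callback), return a cleanup thunk
    stop = threading.Event()

    def tick():
        count = 0
        while not stop.is_set():
            on_next((name, count))
            count += 1
            time.sleep(0.1)

    threading.Thread(target=tick, daemon=True).start()
    return stop.set  # invoked when the last subscriber unsubscribes


TICKS: LazyMulticast[str, tuple] = LazyMulticast(start_ticker)

with TICKS("clock") as ticks:
    for i, tick_event in enumerate(ticks):
        print(tick_event)  # e.g. ("clock", 0)
        if i >= 4:
            break
```

Within the `with` block, `ticks` is a `QueuePollIterable`, so besides
plain iteration you can call `ticks(timeout)` to poll with a deadline
(raising `TimedOut`), or wrap it in `expiring_poll_iter` for an iterator
that stops after a fixed window.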
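
These pieces are designed to snap together: the DynamoDB streams
multicast yields `ItemImages`, `current_nonempty_value` narrows such a
stream to the surviving values of a single item, and `xf_cm` lifts that
transformation over the context manager. A sketch, assuming a
hypothetical table name and item key:

```python
from xoto3.dynamodb.streams.consume import make_dynamodb_stream_images_multicast
from xoto3.dynamodb.streams.records import current_nonempty_value
from xoto3.utils.cm import xf_cm

DYNAMODB_STREAMS = make_dynamodb_stream_images_multicast()

# watch one (hypothetical) item, ignoring deletions and all other items
watch_my_item = xf_cm(current_nonempty_value(dict(id="item4")))

with watch_my_item(DYNAMODB_STREAMS("MyTable")) as current_values:
    for item in current_values:
        print(item)  # the full item as of each create or modify event
```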