diff --git a/storey/__init__.py b/storey/__init__.py index 722f3b10..0f3c0a10 100644 --- a/storey/__init__.py +++ b/storey/__init__.py @@ -36,6 +36,7 @@ from .flow import Batch # noqa: F401 from .flow import Choice # noqa: F401 from .flow import Complete # noqa: F401 +from .flow import ConcurrentExecution # noqa: F401 from .flow import Context # noqa: F401 from .flow import Extend # noqa: F401 from .flow import Filter # noqa: F401 diff --git a/storey/flow.py b/storey/flow.py index 97cc3c2b..f2fd9a6e 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -16,10 +16,12 @@ import copy import datetime import inspect +import pickle import time import traceback from asyncio import Task from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union import aiohttp @@ -800,11 +802,11 @@ def __init__(self, max_in_flight=None, retries=None, backoff_factor=None, **kwar Flow.__init__(self, **kwargs) if max_in_flight is not None and max_in_flight < 1: raise ValueError(f"max_in_flight may not be less than 1 (got {max_in_flight})") - self.max_in_flight = max_in_flight self.retries = retries self.backoff_factor = backoff_factor - self._queue_size = max_in_flight - 1 if max_in_flight else 8 + self._max_in_flight = max_in_flight or 8 + self._queue_size = self._max_in_flight - 1 def _init(self): super()._init() @@ -916,6 +918,73 @@ async def _do(self, event): await self._worker_awaitable +class ConcurrentExecution(_ConcurrentJobExecution): + """ + Inherit this class and override `process_event()` to process events concurrently. + + :param process_event: Function that will be run on each event + + :param concurrency_mechanism: One of: + * "asyncio" (default) – for I/O implemented using asyncio + * "threading" – for blocking I/O + * "multiprocessing" – for processing-intensive tasks + + :param max_in_flight: Maximum number of events to be processed at a time (default 8) + :param retries: Maximum number of retries per event (default 0) + :param backoff_factor: Wait time in seconds between retries (default 1) + :param pass_context: If False, the process_event function will be called with just one parameter (event). If True, + the process_event function will be called with two parameters (event, context). Defaults to False. + """ + + _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] + + def __init__( + self, + event_processor: Union[Callable[[Event], Any], Callable[[Event, Any], Any]], + concurrency_mechanism=None, + pass_context=None, + **kwargs, + ): + super().__init__(**kwargs) + + self._event_processor = event_processor + + if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms: + raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported") + + if concurrency_mechanism == "multiprocessing" and pass_context: + try: + pickle.dumps(self.context) + except Exception as ex: + raise ValueError( + 'When concurrency_mechanism="multiprocessing" is used in conjunction with ' + "pass_context=True, context must be serializable" + ) from ex + + self._executor = None + if concurrency_mechanism == "threading": + self._executor = ThreadPoolExecutor(max_workers=self._max_in_flight) + elif concurrency_mechanism == "multiprocessing": + self._executor = ProcessPoolExecutor(max_workers=self._max_in_flight) + + self._pass_context = pass_context + + async def _process_event(self, event): + args = [event] + if self._pass_context: + args.append(self.context) + if self._executor: + result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args) + else: + result = self._event_processor(*args) + if asyncio.iscoroutine(result): + result = await result + return result + + async def _handle_completed(self, event, response): + await self._do_downstream(response) + + class SendToHttp(_ConcurrentJobExecution): """Joins each event with data from any HTTP source. Used for event augmentation. diff --git a/tests/test_concurrent_execution.py b/tests/test_concurrent_execution.py new file mode 100644 index 00000000..a1f84bbd --- /dev/null +++ b/tests/test_concurrent_execution.py @@ -0,0 +1,78 @@ +import asyncio +import time + +import pytest + +from storey import AsyncEmitSource +from storey.flow import ConcurrentExecution, Reduce, build_flow +from tests.test_flow import append_and_return + +event_processing_duration = 0.5 + + +class SomeContext: + def __init__(self): + self.fn = lambda x: x + + +async def process_event_slow_asyncio(event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + await asyncio.sleep(event_processing_duration) + return event + + +def process_event_slow_io(event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + time.sleep(event_processing_duration) + return event + + +def process_event_slow_processing(event): + start = time.monotonic() + while time.monotonic() - start < event_processing_duration: + pass + return event + + +async def async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context): + controller = build_flow( + [ + AsyncEmitSource(), + ConcurrentExecution( + event_processor=event_processor, + concurrency_mechanism=concurrency_mechanism, + pass_context=pass_context, + max_in_flight=10, + context=SomeContext(), + ), + Reduce([], append_and_return), + ] + ).run() + + num_events = 8 + + start = time.monotonic() + for counter in range(num_events): + await controller.emit(counter) + + await controller.terminate() + result = await controller.await_termination() + end = time.monotonic() + + assert result == list(range(num_events)) + assert end - start > event_processing_duration, "Run time cannot be less than the time to process a single event" + assert ( + end - start < event_processing_duration * num_events + ), "Run time must be less than the time to process all events in serial" + + +@pytest.mark.parametrize( + ["concurrency_mechanism", "event_processor", "pass_context"], + [ + ("asyncio", process_event_slow_asyncio, True), + ("threading", process_event_slow_io, True), + ("multiprocessing", process_event_slow_processing, False), + ], +) +def test_concurrent_execution(concurrency_mechanism, event_processor, pass_context): + asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context))