Add ConcurrentExecution step #511

Merged · 16 commits · Apr 21, 2024
Changes from 8 commits
1 change: 1 addition & 0 deletions storey/__init__.py
@@ -36,6 +36,7 @@
from .flow import Batch # noqa: F401
from .flow import Choice # noqa: F401
from .flow import Complete # noqa: F401
from .flow import ConcurrentExecution # noqa: F401
from .flow import Context # noqa: F401
from .flow import Extend # noqa: F401
from .flow import Filter # noqa: F401
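With this export added, the new step becomes importable from the package root alongside the existing steps listed above; a trivial sanity check:

from storey import ConcurrentExecution  # resolves to storey.flow.ConcurrentExecution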
67 changes: 65 additions & 2 deletions storey/flow.py
@@ -16,10 +16,12 @@
import copy
import datetime
import inspect
import pickle
import time
import traceback
from asyncio import Task
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union

import aiohttp
@@ -800,11 +802,11 @@ def __init__(self, max_in_flight=None, retries=None, backoff_factor=None, **kwar
        Flow.__init__(self, **kwargs)
        if max_in_flight is not None and max_in_flight < 1:
            raise ValueError(f"max_in_flight may not be less than 1 (got {max_in_flight})")
        self.max_in_flight = max_in_flight
        self.retries = retries
        self.backoff_factor = backoff_factor

        self._queue_size = max_in_flight - 1 if max_in_flight else 8
        self.max_in_flight = max_in_flight or 8
        self._queue_size = self.max_in_flight - 1

    def _init(self):
        super()._init()
@@ -916,6 +918,67 @@ async def _do(self, event):
await self._worker_awaitable


class ConcurrentExecution(_ConcurrentJobExecution):
    """
    Runs a user-provided processing function on each event, concurrently.

    :param event_processor: Function that will be run on each event

    :param concurrency_mechanism: One of:
        * "asyncio" (default) – for I/O implemented using asyncio
        * "threading" – for blocking I/O
        * "multiprocessing" – for processing-intensive tasks

    :param pass_context: If True, the flow's context is passed to event_processor as a second argument (disabled by default)
    :param max_in_flight: Maximum number of events to be processed at a time (default 8)
    :param retries: Maximum number of retries per event (default 0)
    :param backoff_factor: Wait time in seconds between retries (default 1)
    """

    _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"]

    def __init__(
        self, event_processor: Callable[[Event], Any], concurrency_mechanism=None, pass_context=None, **kwargs
    ):
        super().__init__(**kwargs)

        self._event_processor = event_processor

        if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms:
            raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported")

        if concurrency_mechanism == "multiprocessing" and pass_context:
            try:
                pickle.dumps(self.context)
            except Exception as ex:
                raise ValueError(
                    'When concurrency_mechanism="multiprocessing" is used in conjunction with '
                    "pass_context=True, context must be serializable"
                ) from ex

        self._executor = None
        if concurrency_mechanism == "threading":
            self._executor = ThreadPoolExecutor(max_workers=self.max_in_flight)
        elif concurrency_mechanism == "multiprocessing":
            self._executor = ProcessPoolExecutor(max_workers=self.max_in_flight)

        self._pass_context = pass_context

    async def _process_event(self, event):
        args = [event]
        if self._pass_context:
            args.append(self.context)
        if self._executor:
            result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args)
        else:
            result = self._event_processor(*args)
        if asyncio.iscoroutine(result):
            result = await result
        return result

    async def _handle_completed(self, event, response):
        await self._do_downstream(response)


class SendToHttp(_ConcurrentJobExecution):
"""Joins each event with data from any HTTP source. Used for event augmentation.

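For reviewers, a minimal end-to-end sketch of the new step, assembled from the API in this diff and the wiring used in the test file below; the `enrich` processor and the local `append_and_return` helper are illustrative stand-ins, not part of the PR:

import asyncio

from storey import AsyncEmitSource
from storey.flow import ConcurrentExecution, Reduce, build_flow


async def enrich(event):
    # Illustrative async I/O: any awaitable work fits here.
    await asyncio.sleep(0.1)
    return event


def append_and_return(lst, event):
    # Local stand-in for the helper the test imports from tests.test_flow.
    lst.append(event)
    return lst


async def main():
    controller = build_flow(
        [
            AsyncEmitSource(),
            # No concurrency_mechanism given, so the default asyncio path is used;
            # up to max_in_flight events are processed at a time.
            ConcurrentExecution(event_processor=enrich, max_in_flight=4),
            Reduce([], append_and_return),
        ]
    ).run()

    for i in range(8):
        await controller.emit(i)
    await controller.terminate()
    print(await controller.await_termination())  # accumulated results from Reduce


asyncio.run(main())

Passing concurrency_mechanism="threading" or "multiprocessing" instead routes the same processor through a ThreadPoolExecutor or ProcessPoolExecutor, as the constructor above shows.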
69 changes: 69 additions & 0 deletions tests/test_concurrent_execution.py
@@ -0,0 +1,69 @@
import asyncio
import time

import pytest

from storey import AsyncEmitSource
from storey.flow import ConcurrentExecution, Reduce, build_flow
from tests.test_flow import append_and_return

event_processing_duration = 0.5


async def process_event_slow_asyncio(event):
    await asyncio.sleep(event_processing_duration)
    return event


def process_event_slow_io(event):
    time.sleep(event_processing_duration)
    return event


def process_event_slow_processing(event):
    start = time.monotonic()
    while time.monotonic() - start < event_processing_duration:
        pass
    return event


async def async_test_concurrent_execution(concurrency_mechanism, event_processor):
    controller = build_flow(
        [
            AsyncEmitSource(),
            ConcurrentExecution(
                event_processor=event_processor,
                concurrency_mechanism=concurrency_mechanism,
                max_in_flight=10,
            ),
            Reduce([], append_and_return),
        ]
    ).run()

    num_events = 8

    start = time.monotonic()
    for counter in range(num_events):
        await controller.emit(counter)

    await controller.terminate()
    result = await controller.await_termination()
    end = time.monotonic()

    assert result == list(range(num_events))
    assert end - start > event_processing_duration, "Run time cannot be less than the time to process a single event"
    assert (
        end - start < event_processing_duration * num_events
    ), "Run time must be less than the time to process all events in serial"


@pytest.mark.parametrize(
    ["concurrency_mechanism", "event_processor"],
    [
        ("asyncio", process_event_slow_asyncio),
        ("threading", process_event_slow_io),
        ("multiprocessing", process_event_slow_processing),
    ],
)
def test_concurrent_execution(concurrency_mechanism, event_processor):
    asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor))
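The pass_context flag introduced in this PR is not exercised by these tests. Judging from _process_event in the diff, a processor that opts in receives the flow's context object as a second positional argument, and the constructor additionally requires a picklable context when combined with multiprocessing. A sketch under those assumptions (the context attribute used here is illustrative):

from storey.flow import ConcurrentExecution


def tag_with_source(event, context):
    # With pass_context=True the step appends the flow's context after the
    # event, so shared configuration or clients can be reached here.
    event.body = {"value": event.body, "source": getattr(context, "source_name", None)}
    return event


step = ConcurrentExecution(
    event_processor=tag_with_source,
    concurrency_mechanism="threading",  # blocking work runs in a thread pool
    pass_context=True,
    max_in_flight=4,
)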