From a9ab8190c60f93ba12e5cfcf98cf7ab898cfda2c Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 11 Apr 2024 16:37:28 +0800 Subject: [PATCH 01/15] Add `ConcurrentExecution` step Based on the preexisting internal `_ConcurrentJobExecution` class. With support for three concurrency mechanisms: asyncio, threading, and multiprocessing. --- storey/flow.py | 49 ++++++++++++++++++++- tests/test_concurrent_execution.py | 69 ++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 tests/test_concurrent_execution.py diff --git a/storey/flow.py b/storey/flow.py index 97cc3c2b..98adefbf 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -20,6 +20,7 @@ import traceback from asyncio import Task from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union import aiohttp @@ -804,7 +805,8 @@ def __init__(self, max_in_flight=None, retries=None, backoff_factor=None, **kwar self.retries = retries self.backoff_factor = backoff_factor - self._queue_size = max_in_flight - 1 if max_in_flight else 8 + max_in_flight = max_in_flight or 8 + self._queue_size = max_in_flight - 1 def _init(self): super()._init() @@ -916,6 +918,51 @@ async def _do(self, event): await self._worker_awaitable +class ConcurrentExecution(_ConcurrentJobExecution): + """ + Inherit this class and override `process_event()` to process events concurrently. + + :param process_event: Function that will be run on each event + + :param concurrency_mechanism: One of: + * "asyncio" (default) – for I/O implemented using asyncio + * "threading" – for blocking I/O + * "multiprocessing" – for processing-intensive tasks + + :param max_in_flight: Maximum number of events to be processed at a time (default 8) + :param retries: Maximum number of retries per event (default 0) + :param backoff_factor: Wait time in seconds between retries (default 1) + """ + + _suppported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] + + def __init__(self, event_processor: Callable[[Event], Any], concurrency_mechanism=None, **kwargs): + super().__init__(**kwargs) + + self._event_processor = event_processor + + if concurrency_mechanism and concurrency_mechanism not in self._suppported_concurrency_mechanisms: + raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported") + + self._executor = None + if concurrency_mechanism == "threading": + self._executor = ThreadPoolExecutor(max_workers=self.max_in_flight) + elif concurrency_mechanism == "multiprocessing": + self._executor = ProcessPoolExecutor(max_workers=self.max_in_flight) + + async def _process_event(self, event): + if self._executor: + result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, event) + else: + result = self._event_processor(event) + if asyncio.iscoroutine(result): + result = await result + return result + + async def _handle_completed(self, event, response): + await self._do_downstream(response) + + class SendToHttp(_ConcurrentJobExecution): """Joins each event with data from any HTTP source. Used for event augmentation. diff --git a/tests/test_concurrent_execution.py b/tests/test_concurrent_execution.py new file mode 100644 index 00000000..2cc8b510 --- /dev/null +++ b/tests/test_concurrent_execution.py @@ -0,0 +1,69 @@ +import asyncio +import time + +import pytest + +from storey import AsyncEmitSource +from storey.flow import ConcurrentExecution, Reduce, build_flow +from tests.test_flow import append_and_return + +event_processing_duration = 0.5 + + +async def process_event_slow_asyncio(event): + await asyncio.sleep(event_processing_duration) + return event + + +def process_event_slow_io(event): + time.sleep(event_processing_duration) + return event + + +def process_event_slow_processing(event): + start = time.monotonic() + while time.monotonic() - start < event_processing_duration: + pass + return event + + +async def async_test_concurrent_execution(concurrency_mechanism, event_processor): + controller = build_flow( + [ + AsyncEmitSource(), + ConcurrentExecution( + event_processor=event_processor, + concurrency_mechanism=concurrency_mechanism, + max_in_flight=10, + ), + Reduce([], append_and_return), + ] + ).run() + + num_events = 8 + + start = time.monotonic() + for counter in range(num_events): + await controller.emit(counter) + + await controller.terminate() + result = await controller.await_termination() + end = time.monotonic() + + assert result == list(range(num_events)) + assert end - start > event_processing_duration, "Run time cannot be less than the time to process a single event" + assert ( + end - start < event_processing_duration * num_events + ), "Run time must be less than the time to process all events in serial" + + +@pytest.mark.parametrize( + ["concurrency_mechanism", "event_processor"], + [ + ("asyncio", process_event_slow_asyncio), + ("threading", process_event_slow_io), + ("multiprocessing", process_event_slow_processing), + ], +) +def test_concurrent_execution(concurrency_mechanism, event_processor): + asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor)) From 0f09dab8acb4cc64e8e1865b690b990e0e0475a1 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Sun, 14 Apr 2024 19:08:27 +0800 Subject: [PATCH 02/15] Fix typo --- storey/flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 98adefbf..d29ad1ae 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -934,14 +934,14 @@ class ConcurrentExecution(_ConcurrentJobExecution): :param backoff_factor: Wait time in seconds between retries (default 1) """ - _suppported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] + _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] def __init__(self, event_processor: Callable[[Event], Any], concurrency_mechanism=None, **kwargs): super().__init__(**kwargs) self._event_processor = event_processor - if concurrency_mechanism and concurrency_mechanism not in self._suppported_concurrency_mechanisms: + if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms: raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported") self._executor = None From cf025fbca92d79b7f3c08c92402b15ef4dab81e3 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Mon, 15 Apr 2024 15:37:04 +0800 Subject: [PATCH 03/15] Expose `ConcurrentExecution` --- storey/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/storey/__init__.py b/storey/__init__.py index 722f3b10..0f3c0a10 100644 --- a/storey/__init__.py +++ b/storey/__init__.py @@ -36,6 +36,7 @@ from .flow import Batch # noqa: F401 from .flow import Choice # noqa: F401 from .flow import Complete # noqa: F401 +from .flow import ConcurrentExecution # noqa: F401 from .flow import Context # noqa: F401 from .flow import Extend # noqa: F401 from .flow import Filter # noqa: F401 From 16d7fe5fa97ba63995f792162af3860bdc0e3715 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Mon, 15 Apr 2024 17:31:41 +0800 Subject: [PATCH 04/15] Allow `ConcurrentExecution` to pass context to user function --- storey/flow.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index d29ad1ae..9f8ed36d 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -936,7 +936,9 @@ class ConcurrentExecution(_ConcurrentJobExecution): _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] - def __init__(self, event_processor: Callable[[Event], Any], concurrency_mechanism=None, **kwargs): + def __init__( + self, event_processor: Callable[[Event], Any], concurrency_mechanism=None, pass_context=None, **kwargs + ): super().__init__(**kwargs) self._event_processor = event_processor @@ -950,11 +952,16 @@ def __init__(self, event_processor: Callable[[Event], Any], concurrency_mechanis elif concurrency_mechanism == "multiprocessing": self._executor = ProcessPoolExecutor(max_workers=self.max_in_flight) + self._pass_context = pass_context + async def _process_event(self, event): + args = [event] + if self._pass_context: + args += self.context if self._executor: - result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, event) + result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args) else: - result = self._event_processor(event) + result = self._event_processor(*args) if asyncio.iscoroutine(result): result = await result return result From 918223dcbf799eab851ac66157a09a8527993a7a Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Mon, 15 Apr 2024 17:55:33 +0800 Subject: [PATCH 05/15] Fix list append --- storey/flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/storey/flow.py b/storey/flow.py index 9f8ed36d..06221a8a 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -957,7 +957,7 @@ def __init__( async def _process_event(self, event): args = [event] if self._pass_context: - args += self.context + args.append(self.context) if self._executor: result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args) else: From d9ac0693ab890e663fcf9f224807e119aea2d875 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Mon, 15 Apr 2024 18:15:10 +0800 Subject: [PATCH 06/15] Fail early on non-serializable context --- storey/flow.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/storey/flow.py b/storey/flow.py index 06221a8a..9f9ad49a 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -16,6 +16,7 @@ import copy import datetime import inspect +import pickle import time import traceback from asyncio import Task @@ -946,6 +947,15 @@ def __init__( if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms: raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported") + if concurrency_mechanism == "multiprocessing" and pass_context: + try: + pickle.dumps(self.context) + except Exception as ex: + raise ValueError( + 'When concurrency_mechanism="multiprocessing" is used in conjunction with ' + "pass_context=True, context must be serializable" + ) from ex + self._executor = None if concurrency_mechanism == "threading": self._executor = ThreadPoolExecutor(max_workers=self.max_in_flight) From f78f0096cf53a472fd9e8ec743d53a2bf745b6c8 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Wed, 17 Apr 2024 18:57:11 +0800 Subject: [PATCH 07/15] Fix passing of default `max_in_flight` as `max_workers` --- storey/flow.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 9f9ad49a..03071de1 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -802,12 +802,11 @@ def __init__(self, max_in_flight=None, retries=None, backoff_factor=None, **kwar Flow.__init__(self, **kwargs) if max_in_flight is not None and max_in_flight < 1: raise ValueError(f"max_in_flight may not be less than 1 (got {max_in_flight})") - self.max_in_flight = max_in_flight self.retries = retries self.backoff_factor = backoff_factor - max_in_flight = max_in_flight or 8 - self._queue_size = max_in_flight - 1 + self.max_in_flight = max_in_flight or 8 + self._queue_size = self.max_in_flight - 1 def _init(self): super()._init() From 682444e32878460e1a3349b747e76ba153e15bea Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Wed, 17 Apr 2024 19:10:30 +0800 Subject: [PATCH 08/15] Fix --- storey/flow.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 03071de1..e309db36 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -805,8 +805,8 @@ def __init__(self, max_in_flight=None, retries=None, backoff_factor=None, **kwar self.retries = retries self.backoff_factor = backoff_factor - self.max_in_flight = max_in_flight or 8 - self._queue_size = self.max_in_flight - 1 + self._max_in_flight = max_in_flight or 8 + self._queue_size = self._max_in_flight - 1 def _init(self): super()._init() @@ -957,9 +957,9 @@ def __init__( self._executor = None if concurrency_mechanism == "threading": - self._executor = ThreadPoolExecutor(max_workers=self.max_in_flight) + self._executor = ThreadPoolExecutor(max_workers=self._max_in_flight) elif concurrency_mechanism == "multiprocessing": - self._executor = ProcessPoolExecutor(max_workers=self.max_in_flight) + self._executor = ProcessPoolExecutor(max_workers=self._max_in_flight) self._pass_context = pass_context From 153b1f7f37a6381f7a65d99f16e4a9c3687191a1 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 18 Apr 2024 14:23:02 +0800 Subject: [PATCH 09/15] Support passing context to multiprocessing step --- requirements.txt | 1 + storey/__init__.py | 1 + storey/flow.py | 41 ++++++++++++++----------- tests/test_concurrent_execution.py | 48 +++++++++++++++++++----------- 4 files changed, 57 insertions(+), 34 deletions(-) diff --git a/requirements.txt b/requirements.txt index 88f94e1c..ab71edba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ fsspec>=0.6.2 v3iofs~=0.1.17 xxhash>=1 nuclio-sdk>=0.5.3 +dill~=0.3.8 diff --git a/storey/__init__.py b/storey/__init__.py index 0f3c0a10..bf9b37bb 100644 --- a/storey/__init__.py +++ b/storey/__init__.py @@ -54,6 +54,7 @@ from .flow import Reduce # noqa: F401 from .flow import Rename # noqa: F401 from .flow import SendToHttp # noqa: F401 +from .flow import UserFunction # noqa: F401 from .flow import build_flow # noqa: F401 from .sources import AsyncEmitSource # noqa: F401 from .sources import CSVSource # noqa: F401 diff --git a/storey/flow.py b/storey/flow.py index e309db36..0bfc8e07 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -16,7 +16,6 @@ import copy import datetime import inspect -import pickle import time import traceback from asyncio import Task @@ -25,6 +24,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union import aiohttp +import dill from .dtypes import Event, FlowError, V3ioError, _termination_obj, known_driver_schemes from .queue import AsyncQueue @@ -918,6 +918,16 @@ async def _do(self, event): await self._worker_awaitable +class UserFunction: + def call(self, *args): + raise NotImplementedError() + + def _unpickle_context_and_call(self, *args): + event, context = args + context = dill.loads(context) + return self.call(event, context) + + class ConcurrentExecution(_ConcurrentJobExecution): """ Inherit this class and override `process_event()` to process events concurrently. @@ -936,9 +946,7 @@ class ConcurrentExecution(_ConcurrentJobExecution): _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] - def __init__( - self, event_processor: Callable[[Event], Any], concurrency_mechanism=None, pass_context=None, **kwargs - ): + def __init__(self, event_processor: UserFunction, concurrency_mechanism=None, pass_context=None, **kwargs): super().__init__(**kwargs) self._event_processor = event_processor @@ -946,15 +954,6 @@ def __init__( if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms: raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported") - if concurrency_mechanism == "multiprocessing" and pass_context: - try: - pickle.dumps(self.context) - except Exception as ex: - raise ValueError( - 'When concurrency_mechanism="multiprocessing" is used in conjunction with ' - "pass_context=True, context must be serializable" - ) from ex - self._executor = None if concurrency_mechanism == "threading": self._executor = ThreadPoolExecutor(max_workers=self._max_in_flight) @@ -965,12 +964,20 @@ def __init__( async def _process_event(self, event): args = [event] - if self._pass_context: - args.append(self.context) if self._executor: - result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args) + func = self._event_processor.call + if isinstance(self._executor, ProcessPoolExecutor): + if self._pass_context: + # dill, unlike pickle, is able to serialize function objects + args.append(dill.dumps(self.context)) + func = self._event_processor._unpickle_context_and_call + elif self._pass_context: + args.append(self.context) + result = await asyncio.get_running_loop().run_in_executor(self._executor, func, *args) else: - result = self._event_processor(*args) + if self._pass_context: + args.append(self.context) + result = self._event_processor.call(*args) if asyncio.iscoroutine(result): result = await result return result diff --git a/tests/test_concurrent_execution.py b/tests/test_concurrent_execution.py index 2cc8b510..61c0ba21 100644 --- a/tests/test_concurrent_execution.py +++ b/tests/test_concurrent_execution.py @@ -4,27 +4,38 @@ import pytest from storey import AsyncEmitSource -from storey.flow import ConcurrentExecution, Reduce, build_flow +from storey.flow import ConcurrentExecution, Reduce, UserFunction, build_flow from tests.test_flow import append_and_return event_processing_duration = 0.5 -async def process_event_slow_asyncio(event): - await asyncio.sleep(event_processing_duration) - return event +class SomeContext: + def __init__(self): + self.fn = lambda x: x -def process_event_slow_io(event): - time.sleep(event_processing_duration) - return event +class ProcessEventSlowAsyncio(UserFunction): + async def call(self, event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + await asyncio.sleep(event_processing_duration) + return event -def process_event_slow_processing(event): - start = time.monotonic() - while time.monotonic() - start < event_processing_duration: - pass - return event +class ProcessEventSlowIO(UserFunction): + def call(self, event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + time.sleep(event_processing_duration) + return event + + +class ProcessEventSlowProcessing(UserFunction): + def call(self, event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + start = time.monotonic() + while time.monotonic() - start < event_processing_duration: + pass + return event async def async_test_concurrent_execution(concurrency_mechanism, event_processor): @@ -34,7 +45,9 @@ async def async_test_concurrent_execution(concurrency_mechanism, event_processor ConcurrentExecution( event_processor=event_processor, concurrency_mechanism=concurrency_mechanism, + pass_context=True, max_in_flight=10, + context=SomeContext(), ), Reduce([], append_and_return), ] @@ -58,12 +71,13 @@ async def async_test_concurrent_execution(concurrency_mechanism, event_processor @pytest.mark.parametrize( - ["concurrency_mechanism", "event_processor"], + ["concurrency_mechanism", "event_processor_class"], [ - ("asyncio", process_event_slow_asyncio), - ("threading", process_event_slow_io), - ("multiprocessing", process_event_slow_processing), + ("asyncio", ProcessEventSlowAsyncio), + ("threading", ProcessEventSlowIO), + ("multiprocessing", ProcessEventSlowProcessing), ], ) -def test_concurrent_execution(concurrency_mechanism, event_processor): +def test_concurrent_execution(concurrency_mechanism, event_processor_class): + event_processor = event_processor_class() asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor)) From 587b91bae58dc51c026b1ae715f9a6d12b411042 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 18 Apr 2024 14:28:09 +0800 Subject: [PATCH 10/15] Minor refactoring --- storey/flow.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 0bfc8e07..18a02deb 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -966,13 +966,13 @@ async def _process_event(self, event): args = [event] if self._executor: func = self._event_processor.call - if isinstance(self._executor, ProcessPoolExecutor): - if self._pass_context: + context = self.context + if self._pass_context: + if isinstance(self._executor, ProcessPoolExecutor): # dill, unlike pickle, is able to serialize function objects - args.append(dill.dumps(self.context)) + context = dill.dumps(self.context) func = self._event_processor._unpickle_context_and_call - elif self._pass_context: - args.append(self.context) + args.append(context) result = await asyncio.get_running_loop().run_in_executor(self._executor, func, *args) else: if self._pass_context: From 3bf0123b52379835c58f229cc05589c0b5cf5d4d Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 18 Apr 2024 17:53:05 +0800 Subject: [PATCH 11/15] Change event processor back to function because of mlrun serialization issues --- storey/__init__.py | 1 - storey/flow.py | 25 ++++++++--------- tests/test_concurrent_execution.py | 44 ++++++++++++++---------------- 3 files changed, 31 insertions(+), 39 deletions(-) diff --git a/storey/__init__.py b/storey/__init__.py index bf9b37bb..0f3c0a10 100644 --- a/storey/__init__.py +++ b/storey/__init__.py @@ -54,7 +54,6 @@ from .flow import Reduce # noqa: F401 from .flow import Rename # noqa: F401 from .flow import SendToHttp # noqa: F401 -from .flow import UserFunction # noqa: F401 from .flow import build_flow # noqa: F401 from .sources import AsyncEmitSource # noqa: F401 from .sources import CSVSource # noqa: F401 diff --git a/storey/flow.py b/storey/flow.py index 18a02deb..6913d06c 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -15,6 +15,7 @@ import asyncio import copy import datetime +import functools import inspect import time import traceback @@ -918,16 +919,6 @@ async def _do(self, event): await self._worker_awaitable -class UserFunction: - def call(self, *args): - raise NotImplementedError() - - def _unpickle_context_and_call(self, *args): - event, context = args - context = dill.loads(context) - return self.call(event, context) - - class ConcurrentExecution(_ConcurrentJobExecution): """ Inherit this class and override `process_event()` to process events concurrently. @@ -946,7 +937,13 @@ class ConcurrentExecution(_ConcurrentJobExecution): _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] - def __init__(self, event_processor: UserFunction, concurrency_mechanism=None, pass_context=None, **kwargs): + @staticmethod + def _unpickle_context_and_call(function, *args): + event, context = args + context = dill.loads(context) + return function(event, context) + + def __init__(self, event_processor: Callable, concurrency_mechanism=None, pass_context=None, **kwargs): super().__init__(**kwargs) self._event_processor = event_processor @@ -965,19 +962,19 @@ def __init__(self, event_processor: UserFunction, concurrency_mechanism=None, pa async def _process_event(self, event): args = [event] if self._executor: - func = self._event_processor.call + func = self._event_processor context = self.context if self._pass_context: if isinstance(self._executor, ProcessPoolExecutor): # dill, unlike pickle, is able to serialize function objects context = dill.dumps(self.context) - func = self._event_processor._unpickle_context_and_call + func = functools.partial(self._unpickle_context_and_call, self._event_processor) args.append(context) result = await asyncio.get_running_loop().run_in_executor(self._executor, func, *args) else: if self._pass_context: args.append(self.context) - result = self._event_processor.call(*args) + result = self._event_processor(*args) if asyncio.iscoroutine(result): result = await result return result diff --git a/tests/test_concurrent_execution.py b/tests/test_concurrent_execution.py index 61c0ba21..ad3e4bcb 100644 --- a/tests/test_concurrent_execution.py +++ b/tests/test_concurrent_execution.py @@ -4,7 +4,7 @@ import pytest from storey import AsyncEmitSource -from storey.flow import ConcurrentExecution, Reduce, UserFunction, build_flow +from storey.flow import ConcurrentExecution, Reduce, build_flow from tests.test_flow import append_and_return event_processing_duration = 0.5 @@ -15,27 +15,24 @@ def __init__(self): self.fn = lambda x: x -class ProcessEventSlowAsyncio(UserFunction): - async def call(self, event, context): - assert isinstance(context, SomeContext) and callable(context.fn) - await asyncio.sleep(event_processing_duration) - return event +async def process_event_slow_asyncio(event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + await asyncio.sleep(event_processing_duration) + return event -class ProcessEventSlowIO(UserFunction): - def call(self, event, context): - assert isinstance(context, SomeContext) and callable(context.fn) - time.sleep(event_processing_duration) - return event +def process_even_slow_io(event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + time.sleep(event_processing_duration) + return event -class ProcessEventSlowProcessing(UserFunction): - def call(self, event, context): - assert isinstance(context, SomeContext) and callable(context.fn) - start = time.monotonic() - while time.monotonic() - start < event_processing_duration: - pass - return event +def process_event_slow_processing(event, context): + assert isinstance(context, SomeContext) and callable(context.fn) + start = time.monotonic() + while time.monotonic() - start < event_processing_duration: + pass + return event async def async_test_concurrent_execution(concurrency_mechanism, event_processor): @@ -71,13 +68,12 @@ async def async_test_concurrent_execution(concurrency_mechanism, event_processor @pytest.mark.parametrize( - ["concurrency_mechanism", "event_processor_class"], + ["concurrency_mechanism", "event_processor"], [ - ("asyncio", ProcessEventSlowAsyncio), - ("threading", ProcessEventSlowIO), - ("multiprocessing", ProcessEventSlowProcessing), + ("asyncio", process_event_slow_asyncio), + ("threading", process_even_slow_io), + ("multiprocessing", process_event_slow_processing), ], ) -def test_concurrent_execution(concurrency_mechanism, event_processor_class): - event_processor = event_processor_class() +def test_concurrent_execution(concurrency_mechanism, event_processor): asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor)) From 30eff1497792526217c912cd7633bf0fec6d175b Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 18 Apr 2024 18:04:48 +0800 Subject: [PATCH 12/15] Move function to avoid serialization issue --- storey/flow.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 6913d06c..5aff3de5 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -919,6 +919,12 @@ async def _do(self, event): await self._worker_awaitable +def _unpickle_context_and_call(function, *args): + event, context = args + context = dill.loads(context) + return function(event, context) + + class ConcurrentExecution(_ConcurrentJobExecution): """ Inherit this class and override `process_event()` to process events concurrently. @@ -937,12 +943,6 @@ class ConcurrentExecution(_ConcurrentJobExecution): _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] - @staticmethod - def _unpickle_context_and_call(function, *args): - event, context = args - context = dill.loads(context) - return function(event, context) - def __init__(self, event_processor: Callable, concurrency_mechanism=None, pass_context=None, **kwargs): super().__init__(**kwargs) @@ -968,7 +968,7 @@ async def _process_event(self, event): if isinstance(self._executor, ProcessPoolExecutor): # dill, unlike pickle, is able to serialize function objects context = dill.dumps(self.context) - func = functools.partial(self._unpickle_context_and_call, self._event_processor) + func = functools.partial(_unpickle_context_and_call, self._event_processor) args.append(context) result = await asyncio.get_running_loop().run_in_executor(self._executor, func, *args) else: From 79d9b2b9099771aac0790641d0e05e56f38d39c5 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 18 Apr 2024 18:36:02 +0800 Subject: [PATCH 13/15] Add documentation --- storey/flow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/storey/flow.py b/storey/flow.py index 5aff3de5..bce22bc0 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -939,6 +939,8 @@ class ConcurrentExecution(_ConcurrentJobExecution): :param max_in_flight: Maximum number of events to be processed at a time (default 8) :param retries: Maximum number of retries per event (default 0) :param backoff_factor: Wait time in seconds between retries (default 1) + :param pass_context: If False, the process_event function will be called with just one parameter (event). If True, + the process_event function will be called with two parameters (event, context). Defaults to False. """ _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] From 87dc1df39c50035f49af71aca3a7d45a7d9134a8 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 18 Apr 2024 19:37:11 +0800 Subject: [PATCH 14/15] Revert attempts to pass context to multiprocessing, add docs --- storey/flow.py | 40 +++++++++++++++--------------- tests/test_concurrent_execution.py | 21 ++++++++-------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index bce22bc0..f2fd9a6e 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -15,8 +15,8 @@ import asyncio import copy import datetime -import functools import inspect +import pickle import time import traceback from asyncio import Task @@ -25,7 +25,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union import aiohttp -import dill from .dtypes import Event, FlowError, V3ioError, _termination_obj, known_driver_schemes from .queue import AsyncQueue @@ -919,12 +918,6 @@ async def _do(self, event): await self._worker_awaitable -def _unpickle_context_and_call(function, *args): - event, context = args - context = dill.loads(context) - return function(event, context) - - class ConcurrentExecution(_ConcurrentJobExecution): """ Inherit this class and override `process_event()` to process events concurrently. @@ -945,7 +938,13 @@ class ConcurrentExecution(_ConcurrentJobExecution): _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] - def __init__(self, event_processor: Callable, concurrency_mechanism=None, pass_context=None, **kwargs): + def __init__( + self, + event_processor: Union[Callable[[Event], Any], Callable[[Event, Any], Any]], + concurrency_mechanism=None, + pass_context=None, + **kwargs, + ): super().__init__(**kwargs) self._event_processor = event_processor @@ -953,6 +952,15 @@ def __init__(self, event_processor: Callable, concurrency_mechanism=None, pass_c if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms: raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported") + if concurrency_mechanism == "multiprocessing" and pass_context: + try: + pickle.dumps(self.context) + except Exception as ex: + raise ValueError( + 'When concurrency_mechanism="multiprocessing" is used in conjunction with ' + "pass_context=True, context must be serializable" + ) from ex + self._executor = None if concurrency_mechanism == "threading": self._executor = ThreadPoolExecutor(max_workers=self._max_in_flight) @@ -963,19 +971,11 @@ def __init__(self, event_processor: Callable, concurrency_mechanism=None, pass_c async def _process_event(self, event): args = [event] + if self._pass_context: + args.append(self.context) if self._executor: - func = self._event_processor - context = self.context - if self._pass_context: - if isinstance(self._executor, ProcessPoolExecutor): - # dill, unlike pickle, is able to serialize function objects - context = dill.dumps(self.context) - func = functools.partial(_unpickle_context_and_call, self._event_processor) - args.append(context) - result = await asyncio.get_running_loop().run_in_executor(self._executor, func, *args) + result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args) else: - if self._pass_context: - args.append(self.context) result = self._event_processor(*args) if asyncio.iscoroutine(result): result = await result diff --git a/tests/test_concurrent_execution.py b/tests/test_concurrent_execution.py index ad3e4bcb..a1f84bbd 100644 --- a/tests/test_concurrent_execution.py +++ b/tests/test_concurrent_execution.py @@ -21,28 +21,27 @@ async def process_event_slow_asyncio(event, context): return event -def process_even_slow_io(event, context): +def process_event_slow_io(event, context): assert isinstance(context, SomeContext) and callable(context.fn) time.sleep(event_processing_duration) return event -def process_event_slow_processing(event, context): - assert isinstance(context, SomeContext) and callable(context.fn) +def process_event_slow_processing(event): start = time.monotonic() while time.monotonic() - start < event_processing_duration: pass return event -async def async_test_concurrent_execution(concurrency_mechanism, event_processor): +async def async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context): controller = build_flow( [ AsyncEmitSource(), ConcurrentExecution( event_processor=event_processor, concurrency_mechanism=concurrency_mechanism, - pass_context=True, + pass_context=pass_context, max_in_flight=10, context=SomeContext(), ), @@ -68,12 +67,12 @@ async def async_test_concurrent_execution(concurrency_mechanism, event_processor @pytest.mark.parametrize( - ["concurrency_mechanism", "event_processor"], + ["concurrency_mechanism", "event_processor", "pass_context"], [ - ("asyncio", process_event_slow_asyncio), - ("threading", process_even_slow_io), - ("multiprocessing", process_event_slow_processing), + ("asyncio", process_event_slow_asyncio, True), + ("threading", process_event_slow_io, True), + ("multiprocessing", process_event_slow_processing, False), ], ) -def test_concurrent_execution(concurrency_mechanism, event_processor): - asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor)) +def test_concurrent_execution(concurrency_mechanism, event_processor, pass_context): + asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context)) From 19fe42a228b39e18101f5d62cddfe4a0b08245c1 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 18 Apr 2024 20:04:42 +0800 Subject: [PATCH 15/15] Remove dill requirement --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ab71edba..88f94e1c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,3 @@ fsspec>=0.6.2 v3iofs~=0.1.17 xxhash>=1 nuclio-sdk>=0.5.3 -dill~=0.3.8