Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate an event ID in Batch step #525

Merged
merged 1 commit into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pickle
import time
import traceback
import uuid
from asyncio import Task
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
Expand Down Expand Up @@ -331,6 +332,20 @@ def _check_step_in_flow(self, type_to_check):
return False


class WithUUID:
    """Mixin that hands out sequential event IDs derived from a shared UUID base.

    A fresh ``uuid4`` hex base is drawn once per 1024 IDs; each ID is the base
    joined with a zero-padded 4-digit counter. This amortizes the cost of UUID
    generation instead of calling ``uuid4()`` for every single event.
    """

    def __init__(self):
        # Hex base shared by up to 1024 consecutive IDs; None forces a refresh
        # on the first call to _get_uuid().
        self._current_uuid_base = None
        self._current_uuid_count = 0

    def _get_uuid(self):
        """Return the next ID string, refreshing the UUID base every 1024 calls."""
        needs_new_base = self._current_uuid_base is None or self._current_uuid_count == 1024
        if needs_new_base:
            self._current_uuid_base = uuid.uuid4().hex
            self._current_uuid_count = 0
        uid = "-".join((self._current_uuid_base, f"{self._current_uuid_count:04}"))
        self._current_uuid_count += 1
        return uid


class Choice(Flow):
"""Redirects each input element into at most one of multiple downstreams.

Expand Down Expand Up @@ -1153,7 +1168,7 @@ async def _emit_all(self):
await self._emit_batch(key)


class Batch(_Batching):
class Batch(_Batching, WithUUID):
"""Batches events into lists of up to max_events events. Each emitted list contains max_events events, unless
flush_after_seconds seconds have passed since the first event in the batch was received, at which point the batch is
emitted with potentially fewer than max_events events.
Expand All @@ -1170,8 +1185,12 @@ class Batch(_Batching):

_do_downstream_per_event = False

def __init__(self, *args, **kwargs):
_Batching.__init__(self, *args, **kwargs)
WithUUID.__init__(self)

async def _emit(self, batch, batch_key, batch_time, batch_events, last_event_time=None):
event = Event(batch)
event = Event(batch, id=self._get_uuid())
if not self._full_event:
# Preserve reference to the original events to avoid early commit of offsets
event._original_events = batch_events
Expand Down
17 changes: 1 addition & 16 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import threading
import time
import traceback
import uuid
import warnings
import weakref
from collections import defaultdict
Expand All @@ -34,7 +33,7 @@
from nuclio_sdk import QualifiedOffset

from .dtypes import Event, _termination_obj
from .flow import Complete, Flow
from .flow import Complete, Flow, WithUUID
from .queue import SimpleAsyncQueue
from .utils import find_filters, find_partitions, url_to_file_system

Expand Down Expand Up @@ -94,20 +93,6 @@ def _convert_to_datetime(obj, time_format: Optional[str] = None):
raise ValueError(f"Could not parse '{obj}' (of type {type(obj)}) as a time.")


class WithUUID:
    """Mixin providing cheap sequential event IDs from a shared UUID base.

    Instead of generating a full UUID per event, one ``uuid4`` hex base is
    reused for up to 1024 IDs, each suffixed with a zero-padded counter.
    """

    def __init__(self):
        # Shared hex base for the current batch of IDs; None until first use.
        self._current_uuid_base = None
        # How many IDs have been issued against the current base (0..1024).
        self._current_uuid_count = 0

    def _get_uuid(self):
        # Refresh the base on first use or after 1024 IDs have been issued.
        if not self._current_uuid_base or self._current_uuid_count == 1024:
            self._current_uuid_base = uuid.uuid4().hex
            self._current_uuid_count = 0
        # ID format: "<32-char hex base>-<4-digit counter>".
        result = f"{self._current_uuid_base}-{self._current_uuid_count:04}"
        self._current_uuid_count += 1
        return result


class FlowControllerBase(WithUUID):
def __init__(
self,
Expand Down
10 changes: 8 additions & 2 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,15 +1768,21 @@ def test_batch():
[
SyncEmitSource(),
Batch(4, 100),
Reduce([], lambda acc, x: append_and_return(acc, x)),
Reduce([], lambda acc, x: append_and_return(acc, x), full_event=True),
]
).run()

for i in range(10):
controller.emit(i)
controller.terminate()
termination_result = controller.await_termination()
assert termination_result == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert len(termination_result) == 3
assert termination_result[0].id
assert termination_result[0].body == [0, 1, 2, 3]
assert termination_result[1].id
assert termination_result[1].body == [4, 5, 6, 7]
assert termination_result[2].id
assert termination_result[2].body == [8, 9]


def test_batch_full_event():
Expand Down
Loading