Skip to content

Commit

Permalink
Fix original events copying (#451)
Browse files Browse the repository at this point in the history
* Fix original events copying

Following #428

* Add regression test

* Add comments

---------

Co-authored-by: Gal Topper <galt@iguazio.com>
  • Loading branch information
gtopper and Gal Topper authored Jul 18, 2023
1 parent 69276d9 commit b5886d3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
7 changes: 5 additions & 2 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,16 @@ async def _do_downstream(self, event):
if len(self._outlets) > 1:
awaitable_result = event._awaitable_result
event._awaitable_result = None
original_events = event._original_events
original_events = getattr(event, "_original_events", None)
# Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop
event._original_events = None
for i in range(1, len(self._outlets)):
event_copy = copy.deepcopy(event)
event_copy._awaitable_result = awaitable_result
event._original_events = original_events
event_copy._original_events = original_events
tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy)))
# Set self-reference back after deepcopy
event._original_events = original_events
event._awaitable_result = awaitable_result
if self.verbose and self.logger:
step_name = self.name
Expand Down
15 changes: 15 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,21 @@ def test_multiple_upstreams():
assert termination_result == 55 + 450


def test_multiple_upstreams_csv_source():
    """Check that a CSV source fanned out to two Map steps merges correctly in one Reduce.

    The source is wired to both Map steps, and both Map steps are wired to the same
    Reduce step. Each Map appends its own tag ("map1" / "map2") to the element it
    receives, and the Reduce accumulates everything into a list, so every row is
    expected to show up once per branch, tagged by the branch it went through.
    """
    csv_source = CSVSource("tests/test.csv")
    first_branch = Map(lambda row: append_and_return(row, "map1"))
    second_branch = Map(lambda row: append_and_return(row, "map2"))
    collector = Reduce([], append_and_return)

    # Fan out: the source feeds both branches (wiring order is kept as map1, then map2).
    csv_source.to(first_branch)
    csv_source.to(second_branch)

    # Fan in: both branches feed the same reducer.
    first_branch.to(collector)
    second_branch.to(collector)

    flow_controller = csv_source.run()
    termination_result = flow_controller.await_termination()

    expected_rows = [
        [1, 2, 3, "map1"],
        [1, 2, 3, "map2"],
        [4, 5, 6, "map1"],
        [4, 5, 6, "map2"],
    ]
    assert termination_result == expected_rows


def test_multiple_upstreams_completion():
source = SyncEmitSource()
map1 = Map(lambda x: x + 1)
Expand Down

0 comments on commit b5886d3

Please sign in to comment.