Skip to content

Commit

Permalink
Update tests to use the new Receiver.is_selected method
Browse files Browse the repository at this point in the history
Signed-off-by: Sahas Subramanian <sahas.subramanian@proton.me>
  • Loading branch information
shsms committed Oct 8, 2024
1 parent 5c48082 commit 1d31b1f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 33 deletions.
16 changes: 8 additions & 8 deletions tests/test_file_watcher_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest

from frequenz.channels import ReceiverStoppedError, select, selected_from
from frequenz.channels import ReceiverStoppedError, select
from frequenz.channels.file_watcher import EventType, FileWatcher
from frequenz.channels.timer import SkipMissedAndDrift, Timer

Expand All @@ -32,9 +32,9 @@ async def test_file_watcher(tmp_path: pathlib.Path) -> None:
timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift())

async for selected in select(file_watcher, timer):
if selected_from(selected, timer):
if timer.is_selected(selected):
filename.write_text(f"{selected.message}")
elif selected_from(selected, file_watcher):
elif file_watcher.is_selected(selected):
event_type = EventType.CREATE if number_of_writes == 0 else EventType.MODIFY
event = selected.message
# If we receive updates for the directory itself, they should be only
Expand Down Expand Up @@ -93,18 +93,18 @@ async def test_file_watcher_deletes(tmp_path: pathlib.Path) -> None:
# D: Delete
# E: FileWatcher Event
async for selected in select(file_watcher, write_timer, deletion_timer):
if selected_from(selected, write_timer):
if write_timer.is_selected(selected):
if number_of_write >= 2 and number_of_events == 0:
continue
filename.write_text(f"{selected.message}")
number_of_write += 1
elif selected_from(selected, deletion_timer):
elif deletion_timer.is_selected(selected):
# Avoid removing the file twice
if not pathlib.Path(filename).is_file():
continue
os.remove(filename)
number_of_deletes += 1
elif selected_from(selected, file_watcher):
elif file_watcher.is_selected(selected):
number_of_events += 1
if number_of_events >= 2:
break
Expand Down Expand Up @@ -135,9 +135,9 @@ async def test_file_watcher_exit_iterator(tmp_path: pathlib.Path) -> None:
timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift())

async for selected in select(file_watcher, timer):
if selected_from(selected, timer):
if timer.is_selected(selected):
filename.write_text(f"{selected.message}")
elif selected_from(selected, file_watcher):
elif file_watcher.is_selected(selected):
number_of_writes += 1
if number_of_writes == expected_number_of_writes:
file_watcher._stop_event.set() # pylint: disable=protected-access
Expand Down
8 changes: 4 additions & 4 deletions tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from frequenz.channels import Receiver, ReceiverStoppedError, Selected, selected_from
from frequenz.channels import Receiver, ReceiverStoppedError, Selected


class TestSelected:
Expand All @@ -19,7 +19,7 @@ def test_with_message(self) -> None:
recv.consume.return_value = 42
selected = Selected[int](recv)

assert selected_from(selected, recv)
assert recv.is_selected(selected)
assert selected.message == 42
assert selected.exception is None
assert not selected.was_stopped
Expand All @@ -31,7 +31,7 @@ def test_with_exception(self) -> None:
recv.consume.side_effect = exception
selected = Selected[int](recv)

assert selected_from(selected, recv)
assert recv.is_selected(selected)
with pytest.raises(Exception, match="test"):
_ = selected.message
assert selected.exception is exception
Expand All @@ -44,7 +44,7 @@ def test_with_stopped(self) -> None:
recv.consume.side_effect = exception
selected = Selected[int](recv)

assert selected_from(selected, recv)
assert recv.is_selected(selected)
with pytest.raises(
ReceiverStoppedError,
match=r"Receiver <MagicMock spec='_GenericAlias' id='\d+'> was stopped",
Expand Down
41 changes: 20 additions & 21 deletions tests/test_select_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class at a time.
Selected,
UnhandledSelectedError,
select,
selected_from,
)
from frequenz.channels.event import Event

Expand Down Expand Up @@ -87,7 +86,7 @@ def assert_received_from(
number is negative, a > check is performed with the absolute value. If
it is 0, no check is performed.
"""
assert selected_from(selected, receiver)
assert receiver.is_selected(selected)
assert selected.message is None
assert selected.exception is None
assert not selected.was_stopped
Expand Down Expand Up @@ -120,7 +119,7 @@ def assert_receiver_stopped(
number is negative, a > check is performed with the absolute value. If
it is 0, no check is performed.
"""
assert selected_from(selected, receiver)
assert receiver.is_selected(selected)
assert selected.was_stopped
assert isinstance(selected.exception, ReceiverStoppedError)
assert selected.exception.receiver is receiver
Expand Down Expand Up @@ -245,33 +244,33 @@ async def test_break(
"""Test that break works."""
selected: Selected[Any] | None = None
async for selected in select(self.recv1, self.recv2, self.recv3):
if selected_from(selected, self.recv1):
if self.recv1.is_selected(selected):
continue
if selected_from(selected, self.recv2):
if self.recv2.is_selected(selected):
continue
if selected_from(selected, self.recv3):
if self.recv3.is_selected(selected):
break

assert selected is not None
self.assert_received_from(selected, self.recv3, at_time=2)

async for selected in select(self.recv1, self.recv2, self.recv3):
if selected_from(selected, self.recv1):
if self.recv1.is_selected(selected):
continue
if selected_from(selected, self.recv2):
if self.recv2.is_selected(selected):
break
if selected_from(selected, self.recv3):
if self.recv3.is_selected(selected):
continue

assert selected is not None
self.assert_received_from(selected, self.recv2, at_time=6)

async for selected in select(self.recv1, self.recv2, self.recv3):
if selected_from(selected, self.recv1):
if self.recv1.is_selected(selected):
continue
if selected_from(selected, self.recv2):
if self.recv2.is_selected(selected):
continue
if selected_from(selected, self.recv3):
if self.recv3.is_selected(selected):
break

assert selected is not None
Expand All @@ -281,7 +280,7 @@ async def test_break(
assert self.recv3.is_stopped

async for selected in select(self.recv2):
if selected_from(selected, self.recv2):
if self.recv2.is_selected(selected):
continue

self.assert_receiver_stopped(
Expand All @@ -300,9 +299,9 @@ async def test_missed_select_from(
selected: Selected[Any] | None = None
with pytest.raises(UnhandledSelectedError) as excinfo:
async for selected in select(self.recv1, self.recv2, self.recv3):
if selected_from(selected, self.recv1):
if self.recv1.is_selected(selected):
continue
if selected_from(selected, self.recv2):
if self.recv2.is_selected(selected):
continue

assert False, "Should not reach this point"
Expand Down Expand Up @@ -392,11 +391,11 @@ async def test_multiple_ready(
received.clear()
last_time = now

if selected_from(selected, self.recv1):
if self.recv1.is_selected(selected):
received.add(self.recv1.name)
elif selected_from(selected, self.recv2):
elif self.recv2.is_selected(selected):
received.add(self.recv2.name)
elif selected_from(selected, self.recv3):
elif self.recv3.is_selected(selected):
received.add(self.recv3.name)
else:
assert False, "Should not reach this point"
Expand Down Expand Up @@ -425,11 +424,11 @@ def test_tasks_are_cleaned_up_with_break(self) -> None:
async def run() -> None:
task = loop.create_task(self.run_multiple_ready())
async for selected in select(self.recv1, self.recv2, self.recv3):
if selected_from(selected, self.recv1):
if self.recv1.is_selected(selected):
continue
if selected_from(selected, self.recv2):
if self.recv2.is_selected(selected):
continue
if selected_from(selected, self.recv3):
if self.recv3.is_selected(selected):
break

task.cancel()
Expand Down

0 comments on commit 1d31b1f

Please sign in to comment.