diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py index bee369ca..edc19e4e 100644 --- a/src/anyio/streams/memory.py +++ b/src/anyio/streams/memory.py @@ -1,6 +1,6 @@ from collections import deque, OrderedDict from dataclasses import dataclass, field -from typing import TypeVar, Generic, List, Deque, Tuple +from typing import TypeVar, Generic, List, Deque from .. import get_cancelled_exc_class from .._core._lowlevel import checkpoint @@ -18,7 +18,8 @@ class MemoryObjectStreamState(Generic[T_Item]): buffer: Deque[T_Item] = field(init=False, default_factory=deque) open_send_channels: int = field(init=False, default=0) open_receive_channels: int = field(init=False, default=0) - waiting_receivers: Deque[Tuple[Event, List[T_Item]]] = field(init=False, default_factory=deque) + waiting_receivers: 'OrderedDict[Event, List[T_Item]]' = field(init=False, + default_factory=OrderedDict) waiting_senders: 'OrderedDict[Event, T_Item]' = field(init=False, default_factory=OrderedDict) @@ -66,34 +67,17 @@ async def receive(self) -> T_Item: # Add ourselves in the queue receive_event = create_event() container: List[T_Item] = [] - ticket = receive_event, container - self._state.waiting_receivers.append(ticket) + self._state.waiting_receivers[receive_event] = container try: await receive_event.wait() except get_cancelled_exc_class(): - # If we already received an item in the container, pass it to the next receiver in - # line - index = self._state.waiting_receivers.index(ticket) + 1 - if container: - item = container[0] - while index < len(self._state.waiting_receivers): - receive_event, container = self._state.waiting_receivers[index] - if container: - item, container[0] = container[0], item - else: - # Found an untriggered receiver - container.append(item) - await receive_event.set() - break - else: - # Could not find an untriggered receiver, so in order to not lose any - # items, put it in the buffer, even if it exceeds the maximum buffer size - self._state.buffer.append(item) - - raise + # Ignore the immediate cancellation if we already received an item, so as not to + # lose it + if not container: + raise finally: - self._state.waiting_receivers.remove(ticket) + self._state.waiting_receivers.pop(receive_event, None) if container: return container[0] @@ -151,13 +135,11 @@ async def send_nowait(self, item: T_Item) -> None: if not self._state.open_receive_channels: raise BrokenResourceError - for receive_event, container in self._state.waiting_receivers: - if not container: - container.append(item) - await receive_event.set() - return - - if len(self._state.buffer) < self._state.max_buffer_size: + if self._state.waiting_receivers: + receive_event, container = self._state.waiting_receivers.popitem(last=False) + container.append(item) + await receive_event.set() + elif len(self._state.buffer) < self._state.max_buffer_size: self._state.buffer.append(item) else: raise WouldBlock @@ -199,6 +181,7 @@ async def aclose(self) -> None: self._closed = True self._state.open_send_channels -= 1 if self._state.open_send_channels == 0: - receive_events = [event for event, container in self._state.waiting_receivers] + receive_events = list(self._state.waiting_receivers.keys()) + self._state.waiting_receivers.clear() for event in receive_events: await event.set() diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py index 4c5f0f8b..a37d1931 100644 --- a/tests/streams/test_memory.py +++ b/tests/streams/test_memory.py @@ -231,10 +231,9 @@ async def test_cancel_during_receive(): async def scoped_receiver(): nonlocal receiver_scope async with open_cancel_scope() as receiver_scope: - await receive.receive() + received.append(await receive.receive()) - async def receiver(): - received.append(await receive.receive()) + assert receiver_scope.cancel_called receiver_scope = None received = [] @@ -242,34 +241,7 @@ async def receiver(): async with create_task_group() as tg: await tg.spawn(scoped_receiver) await wait_all_tasks_blocked() - await tg.spawn(receiver) + await send.send_nowait('hello') await receiver_scope.cancel() - await send.send('hello') assert received == ['hello'] - - -async def test_cancel_during_receive_last_receiver(): - """ - Test that cancelling a pending receive() operation does not cause an item in the stream to be - lost, even if there are no other receivers waiting. - - """ - async def scoped_receiver(): - nonlocal receiver_scope - async with open_cancel_scope() as receiver_scope: - await receive.receive() - pytest.fail('This point should never be reached') - - receiver_scope = None - send, receive = create_memory_object_stream() - async with create_task_group() as tg: - await tg.spawn(scoped_receiver) - await wait_all_tasks_blocked() - await receiver_scope.cancel() - await send.send_nowait('hello') - - with pytest.raises(WouldBlock): - await send.send_nowait('world') - - assert await receive.receive_nowait() == 'hello'