diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 94eaf15f..7931bf46 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -8,6 +8,8 @@ This library adheres to `Semantic Versioning 2.0 `_. - Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``, ``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by Lura Skye) +- Enabled the ``Event`` and ``CapacityLimiter`` classes to be instantiated outside an + event loop thread - Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help from Egor Blagov) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index eb891d22..2caa2a43 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -615,7 +615,9 @@ def set(self) -> None: class CapacityLimiter(BaseCapacityLimiter): - def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter: + def __new__( + cls, *args: Any, original: trio.CapacityLimiter | None = None + ) -> CapacityLimiter: return object.__new__(cls) def __init__( diff --git a/src/anyio/_core/_synchronization.py b/src/anyio/_core/_synchronization.py index fdd4f5fb..33172dcd 100644 --- a/src/anyio/_core/_synchronization.py +++ b/src/anyio/_core/_synchronization.py @@ -1,9 +1,12 @@ from __future__ import annotations +import math from collections import deque from dataclasses import dataclass from types import TracebackType +from sniffio import AsyncLibraryNotFoundError + from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled from ._eventloop import get_async_backend from ._exceptions import BusyResourceError, WouldBlock @@ -76,7 +79,10 @@ class SemaphoreStatistics: class Event: def __new__(cls) -> Event: - return get_async_backend().create_event() + try: + return get_async_backend().create_event() + except AsyncLibraryNotFoundError: + return EventAdapter() def set(self) -> None: """Set the flag, notifying all listeners.""" @@ -101,6 +107,35 @@ def statistics(self) -> EventStatistics: raise NotImplementedError +class EventAdapter(Event): + _internal_event: Event | None = None + + def __new__(cls) -> EventAdapter: + return object.__new__(cls) + + @property + def _event(self) -> Event: + if self._internal_event is None: + self._internal_event = get_async_backend().create_event() + + return self._internal_event + + def set(self) -> None: + self._event.set() + + def is_set(self) -> bool: + return self._internal_event is not None and self._internal_event.is_set() + + async def wait(self) -> None: + await self._event.wait() + + def statistics(self) -> EventStatistics: + if self._internal_event is None: + return EventStatistics(tasks_waiting=0) + + return self._internal_event.statistics() + + class Lock: _owner_task: TaskInfo | None = None @@ -373,7 +408,10 @@ def statistics(self) -> SemaphoreStatistics: class CapacityLimiter: def __new__(cls, total_tokens: float) -> CapacityLimiter: - return get_async_backend().create_capacity_limiter(total_tokens) + try: + return get_async_backend().create_capacity_limiter(total_tokens) + except AsyncLibraryNotFoundError: + return CapacityLimiterAdapter(total_tokens) async def __aenter__(self) -> None: raise NotImplementedError @@ -482,6 +520,99 @@ def statistics(self) -> CapacityLimiterStatistics: raise NotImplementedError +class CapacityLimiterAdapter(CapacityLimiter): + _internal_limiter: CapacityLimiter | None = None + + def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter: + return object.__new__(cls) + + def __init__(self, total_tokens: float) -> None: + self.total_tokens = total_tokens + + @property + def _limiter(self) -> CapacityLimiter: + if self._internal_limiter is None: + self._internal_limiter = get_async_backend().create_capacity_limiter( + self._total_tokens + ) + + return self._internal_limiter + + async def __aenter__(self) -> None: + await self._limiter.__aenter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return await self._limiter.__aexit__(exc_type, exc_val, exc_tb) + + @property + def total_tokens(self) -> float: + if self._internal_limiter is None: + return self._total_tokens + + return self._internal_limiter.total_tokens + + @total_tokens.setter + def total_tokens(self, value: float) -> None: + if not isinstance(value, int) and value is not math.inf: + raise TypeError("total_tokens must be an int or math.inf") + elif value < 1: + raise ValueError("total_tokens must be >= 1") + + if self._internal_limiter is None: + self._total_tokens = value + return + + self._limiter.total_tokens = value + + @property + def borrowed_tokens(self) -> int: + if self._internal_limiter is None: + return 0 + + return self._internal_limiter.borrowed_tokens + + @property + def available_tokens(self) -> float: + if self._internal_limiter is None: + return self._total_tokens + + return self._internal_limiter.available_tokens + + def acquire_nowait(self) -> None: + self._limiter.acquire_nowait() + + def acquire_on_behalf_of_nowait(self, borrower: object) -> None: + self._limiter.acquire_on_behalf_of_nowait(borrower) + + async def acquire(self) -> None: + await self._limiter.acquire() + + async def acquire_on_behalf_of(self, borrower: object) -> None: + await self._limiter.acquire_on_behalf_of(borrower) + + def release(self) -> None: + self._limiter.release() + + def release_on_behalf_of(self, borrower: object) -> None: + self._limiter.release_on_behalf_of(borrower) + + def statistics(self) -> CapacityLimiterStatistics: + if self._internal_limiter is None: + return CapacityLimiterStatistics( + borrowed_tokens=0, + total_tokens=self.total_tokens, + borrowers=(), + tasks_waiting=0, + ) + + return self._internal_limiter.statistics() + + class ResourceGuard: """ A context manager for ensuring that a resource is only used by a single task at a diff --git a/tests/test_synchronization.py b/tests/test_synchronization.py index 6011710a..5eddc066 100644 --- a/tests/test_synchronization.py +++ b/tests/test_synchronization.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from typing import Any import pytest @@ -13,6 +14,7 @@ WouldBlock, create_task_group, fail_after, + run, to_thread, wait_all_tasks_blocked, ) @@ -141,6 +143,21 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) + def test_instantiate_outside_event_loop( + self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] + ) -> None: + async def use_lock() -> None: + async with lock: + pass + + lock = Lock() + statistics = lock.statistics() + assert not statistics.locked + assert statistics.owner is None + assert statistics.tasks_waiting == 0 + + run(use_lock, backend=anyio_backend_name, backend_options=anyio_backend_options) + class TestEvent: async def test_event(self) -> None: @@ -208,6 +225,21 @@ async def waiter() -> None: assert event.statistics().tasks_waiting == 0 + def test_instantiate_outside_event_loop( + self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] + ) -> None: + async def use_event() -> None: + event.set() + await event.wait() + + event = Event() + assert not event.is_set() + assert event.statistics().tasks_waiting == 0 + + run( + use_event, backend=anyio_backend_name, backend_options=anyio_backend_options + ) + class TestCondition: async def test_contextmanager(self) -> None: @@ -304,6 +336,22 @@ async def waiter() -> None: assert not condition.statistics().lock_statistics.locked assert condition.statistics().tasks_waiting == 0 + def test_instantiate_outside_event_loop( + self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] + ) -> None: + async def use_condition() -> None: + async with condition: + pass + + condition = Condition() + assert condition.statistics().tasks_waiting == 0 + + run( + use_condition, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) + class TestSemaphore: async def test_contextmanager(self) -> None: @@ -426,6 +474,22 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) + def test_instantiate_outside_event_loop( + self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] + ) -> None: + async def use_semaphore() -> None: + async with semaphore: + pass + + semaphore = Semaphore(1) + assert semaphore.statistics().tasks_waiting == 0 + + run( + use_semaphore, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) + class TestCapacityLimiter: async def test_bad_init_type(self) -> None: @@ -595,3 +659,33 @@ async def worker(entered_event: Event) -> None: # Allow all tasks to exit continue_event.set() + + def test_instantiate_outside_event_loop( + self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] + ) -> None: + async def use_limiter() -> None: + async with limiter: + pass + + limiter = CapacityLimiter(1) + limiter.total_tokens = 2 + + with pytest.raises(TypeError): + limiter.total_tokens = "2" # type: ignore[assignment] + + with pytest.raises(TypeError): + limiter.total_tokens = 3.0 + + assert limiter.total_tokens == 2 + assert limiter.borrowed_tokens == 0 + statistics = limiter.statistics() + assert statistics.total_tokens == 2 + assert statistics.borrowed_tokens == 0 + assert statistics.borrowers == () + assert statistics.tasks_waiting == 0 + + run( + use_limiter, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + )