Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled Event and CapacityLimiter to be instantiated outside an event loop #651

Merged
merged 9 commits into from
Dec 14, 2023
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Add support for ``bytes``-based paths in ``connect_unix``, ``create_unix_listeners``,
``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by
Lura Skye)
- Enabled the ``Event`` and ``CapacityLimiter`` classes to be instantiated outside an
event loop thread
- Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing
to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help
from Egor Blagov)
Expand Down
4 changes: 3 additions & 1 deletion src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,9 @@ def set(self) -> None:


class CapacityLimiter(BaseCapacityLimiter):
def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter:
def __new__(
cls, *args: Any, original: trio.CapacityLimiter | None = None
) -> CapacityLimiter:
return object.__new__(cls)

def __init__(
Expand Down
131 changes: 129 additions & 2 deletions src/anyio/_core/_synchronization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import math
from collections import deque
from dataclasses import dataclass
from types import TracebackType

from sniffio import AsyncLibraryNotFoundError

from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
from ._eventloop import get_async_backend
from ._exceptions import BusyResourceError, WouldBlock
Expand Down Expand Up @@ -76,7 +79,10 @@ class SemaphoreStatistics:

class Event:
def __new__(cls) -> Event:
return get_async_backend().create_event()
try:
return get_async_backend().create_event()
except AsyncLibraryNotFoundError:
return EventAdapter()

def set(self) -> None:
"""Set the flag, notifying all listeners."""
Expand All @@ -101,6 +107,35 @@ def statistics(self) -> EventStatistics:
raise NotImplementedError


class EventAdapter(Event):
_internal_event: Event | None = None

def __new__(cls) -> EventAdapter:
    """
    Allocate the adapter directly through ``object.__new__``.

    This overrides the ``__new__`` inherited from ``Event`` so that no backend
    lookup takes place when the adapter itself is constructed.

    """
    instance = object.__new__(cls)
    return instance

@property
def _event(self) -> Event:
    """
    Return the backend event, creating it on first access.

    The event is created through ``get_async_backend()`` and cached in
    ``_internal_event`` so later accesses return the same object.

    """
    backend_event = self._internal_event
    if backend_event is None:
        backend_event = self._internal_event = get_async_backend().create_event()

    return backend_event

def set(self) -> None:
    """Set the flag on the backend event (obtained through ``_event``)."""
    backend_event = self._event
    backend_event.set()

def is_set(self) -> bool:
    """
    Return ``True`` if the flag is set, ``False`` if not.

    This never creates the backend event: while it does not exist yet, the
    answer is always ``False``.

    """
    backend_event = self._internal_event
    if backend_event is None:
        return False

    return backend_event.is_set()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that if the event is not created then it cannot be set, but should is_set actually create the event?

Suggested change
return self._internal_event is not None and self._internal_event.is_set()
return self._event.is_set()

statistics also doesn't need the event to be created (it could return empty statistics if it's not), yet it creates it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps statistics should have a special case (like is_set does) to allow lookup from outside of the event loop context?

Lock.statistics() for instance is accessible from outside of the event loop already.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.


async def wait(self) -> None:
    """Wait on the backend event (obtained through ``_event``)."""
    backend_event = self._event
    await backend_event.wait()

def statistics(self) -> EventStatistics:
    """
    Return statistics about the current state of this event.

    If the backend event has not been created yet, this reports zero waiting
    tasks without creating it.

    """
    backend_event = self._internal_event
    if backend_event is not None:
        return backend_event.statistics()

    return EventStatistics(tasks_waiting=0)


class Lock:
_owner_task: TaskInfo | None = None

Expand Down Expand Up @@ -373,7 +408,10 @@ def statistics(self) -> SemaphoreStatistics:

class CapacityLimiter:
def __new__(cls, total_tokens: float) -> CapacityLimiter:
return get_async_backend().create_capacity_limiter(total_tokens)
try:
return get_async_backend().create_capacity_limiter(total_tokens)
except AsyncLibraryNotFoundError:
return CapacityLimiterAdapter(total_tokens)

async def __aenter__(self) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -482,6 +520,95 @@ def statistics(self) -> CapacityLimiterStatistics:
raise NotImplementedError


class CapacityLimiterAdapter(CapacityLimiter):
_internal_limiter: CapacityLimiter | None = None

def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter:
    """
    Allocate the adapter directly through ``object.__new__``.

    ``total_tokens`` is accepted only to match the constructor signature; it is
    not used here.

    """
    instance = object.__new__(cls)
    return instance

def __init__(self, total_tokens: float) -> None:
    """
    Validate and remember the token count for the backend limiter created later.

    :param total_tokens: the total number of tokens; an ``int`` >= 1 or infinity
    :raises TypeError: if ``total_tokens`` is neither an ``int`` nor positive
        infinity
    :raises ValueError: if ``total_tokens`` is less than 1

    """
    # Compare against infinity by value, not identity: float("inf") is a distinct
    # object from math.inf but denotes the same limit and must be accepted too.
    if not isinstance(total_tokens, int) and total_tokens != math.inf:
        raise TypeError("total_tokens must be an int or math.inf")
    elif total_tokens < 1:
        raise ValueError("total_tokens must be >= 1")

    self._total_tokens = total_tokens

@property
def _limiter(self) -> CapacityLimiter:
    """
    Return the backend limiter, creating it on first access.

    The limiter is created through ``get_async_backend()`` with the stored
    ``_total_tokens`` value and cached in ``_internal_limiter``.

    """
    backend_limiter = self._internal_limiter
    if backend_limiter is None:
        backend = get_async_backend()
        backend_limiter = backend.create_capacity_limiter(self._total_tokens)
        self._internal_limiter = backend_limiter

    return backend_limiter

async def __aenter__(self) -> None:
    """Enter the backend limiter's async context; its return value is discarded."""
    backend_limiter = self._limiter
    await backend_limiter.__aenter__()

async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> bool | None:
    """
    Leave the backend limiter's async context.

    The exception details are passed through unchanged and the backend
    limiter's return value is returned as is.

    """
    backend_limiter = self._limiter
    outcome = await backend_limiter.__aexit__(exc_type, exc_val, exc_tb)
    return outcome

@property
def total_tokens(self) -> float:
    """
    The total number of tokens available for borrowing.

    Until the backend limiter has been created, this is the value stored on the
    adapter; afterwards it is read from the backend limiter.

    """
    if self._internal_limiter is None:
        return self._total_tokens

    return self._internal_limiter.total_tokens

@total_tokens.setter
def total_tokens(self, value: float) -> None:
    """
    Change the total number of tokens.

    :raises TypeError: if the backend limiter does not exist yet and ``value`` is
        neither an ``int`` nor positive infinity
    :raises ValueError: if the backend limiter does not exist yet and ``value``
        is less than 1

    """
    if self._internal_limiter is None:
        # Do not force creation of the backend limiter here: that requires a
        # running event loop, and this adapter exists precisely so it can be
        # used before one is running. Validate now so a bad value fails at the
        # assignment rather than when the backend limiter is eventually created.
        if not isinstance(value, int) and value != math.inf:
            raise TypeError("total_tokens must be an int or math.inf")
        elif value < 1:
            raise ValueError("total_tokens must be >= 1")

        self._total_tokens = value
    else:
        self._internal_limiter.total_tokens = value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(obscure/minor comment) could this be delayed too?

Suggested change
self._limiter.total_tokens = value
if self._internal_limiter is None:
self._total_tokens = value
else:
self._limiter.total_tokens = value

Copy link
Collaborator

@gschaffner gschaffner Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mhm, i suppose that if total_tokens.setter is delayed then this adapter might not be thread-safe. is it intended to be thread-safe?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Frankly, the idea that someone would change the adapter's total tokens from a worker thread had not even occurred to me. The whole idea seems pretty bizarre to me.

Copy link
Collaborator

@gschaffner gschaffner Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, i think that it's reasonable to require them to do the mutation in the event loop thread. CancelScope.cancel for example is not safe to call from a non-main thread or a signal handler. Event.set is not thread-safe either. it's extra not thread safe after adding EventAdapter, but it wasn't thread-safe before either.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll revert my latest commit then?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more I think about it, the more I'm leaning towards allowing users to mutate total_tokens before the initialization.

Copy link
Collaborator

@gschaffner gschaffner Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

huh, does Trio document that?

i'm peeking at the Trio code and total_tokens.setter looks thread-unsafe—there's no threading.Lock, so if two threads mutate total_tokens in parallel then two _wake_waiters will run, which can cause too many waiters to get woken up.

(anyio.CapacityLimiter.)total_tokens.setter is currently thread-unsafe, though, so i think that d8ba6de is in alignment with the status quo.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more I think about it, the more I'm leaning towards allowing users to mutate total_tokens before the initialization.

i'm leaning this way too.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, generally the API should be considered thread-unsafe, save for the parts that were explicitly designed to allow access from worker threads (the anyio.from_thread module).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops—i misinterpreted

It might be reasonable to allow the tokens to be set from outside the event loop, just like trio.CapacityLimiter allows.

as

It might be reasonable to allow the tokens to be set from other threads, just like trio.CapacityLimiter allows.

and was replying to something you didn't say. :p


@property
def borrowed_tokens(self) -> int:
    """
    The number of tokens that have currently been borrowed.

    Reported as 0 while the backend limiter has not been created yet.

    """
    backend_limiter = self._internal_limiter
    return 0 if backend_limiter is None else backend_limiter.borrowed_tokens

@property
def available_tokens(self) -> float:
    """
    The number of tokens currently available to be borrowed.

    While the backend limiter has not been created yet, this equals the stored
    total token count.

    """
    backend_limiter = self._internal_limiter
    if backend_limiter is not None:
        return backend_limiter.available_tokens

    return self._total_tokens

def acquire_nowait(self) -> None:
    """Forward to ``acquire_nowait()`` on the backend limiter."""
    backend_limiter = self._limiter
    backend_limiter.acquire_nowait()

def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
    """
    Forward to ``acquire_on_behalf_of_nowait()`` on the backend limiter.

    :param borrower: passed through unchanged to the backend limiter

    """
    backend_limiter = self._limiter
    backend_limiter.acquire_on_behalf_of_nowait(borrower)

async def acquire(self) -> None:
    """Forward to ``acquire()`` on the backend limiter and await it."""
    backend_limiter = self._limiter
    await backend_limiter.acquire()

async def acquire_on_behalf_of(self, borrower: object) -> None:
    """
    Forward to ``acquire_on_behalf_of()`` on the backend limiter and await it.

    :param borrower: passed through unchanged to the backend limiter

    """
    backend_limiter = self._limiter
    await backend_limiter.acquire_on_behalf_of(borrower)

def release(self) -> None:
    """Forward to ``release()`` on the backend limiter."""
    backend_limiter = self._limiter
    backend_limiter.release()

def release_on_behalf_of(self, borrower: object) -> None:
    """
    Forward to ``release_on_behalf_of()`` on the backend limiter.

    :param borrower: passed through unchanged to the backend limiter

    """
    backend_limiter = self._limiter
    backend_limiter.release_on_behalf_of(borrower)

def statistics(self) -> CapacityLimiterStatistics:
    """
    Return statistics about the current state of this limiter.

    If the backend limiter has not been created yet, this reports no borrowed
    tokens, no borrowers and no waiting tasks, along with the current
    ``total_tokens`` value, without creating the backend limiter.

    """
    backend_limiter = self._internal_limiter
    if backend_limiter is not None:
        return backend_limiter.statistics()

    return CapacityLimiterStatistics(
        borrowed_tokens=0,
        total_tokens=self.total_tokens,
        borrowers=(),
        tasks_waiting=0,
    )


class ResourceGuard:
"""
A context manager for ensuring that a resource is only used by a single task at a
Expand Down
86 changes: 86 additions & 0 deletions tests/test_synchronization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from typing import Any

import pytest

Expand All @@ -13,6 +14,7 @@
WouldBlock,
create_task_group,
fail_after,
run,
to_thread,
wait_all_tasks_blocked,
)
Expand Down Expand Up @@ -141,6 +143,21 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

def test_instantiate_outside_event_loop(
    self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
    """Create a Lock with no event loop running, then use it inside ``run()``."""
    lock = Lock()

    async def acquire_and_release() -> None:
        async with lock:
            pass

    # Statistics must be readable before any event loop has been started
    stats = lock.statistics()
    assert not stats.locked
    assert stats.owner is None
    assert stats.tasks_waiting == 0

    run(
        acquire_and_release,
        backend=anyio_backend_name,
        backend_options=anyio_backend_options,
    )


class TestEvent:
async def test_event(self) -> None:
Expand Down Expand Up @@ -208,6 +225,21 @@ async def waiter() -> None:

assert event.statistics().tasks_waiting == 0

def test_instantiate_outside_event_loop(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to prevent the other synchronization primitives from regressing on this issue, it could make sense to add an analogous test_instantiate_outside_event_loop for each of them.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
async def use_event() -> None:
event.set()
await event.wait()

event = Event()
assert not event.is_set()
assert event.statistics().tasks_waiting == 0

run(
use_event, backend=anyio_backend_name, backend_options=anyio_backend_options
)


class TestCondition:
async def test_contextmanager(self) -> None:
Expand Down Expand Up @@ -304,6 +336,22 @@ async def waiter() -> None:
assert not condition.statistics().lock_statistics.locked
assert condition.statistics().tasks_waiting == 0

def test_instantiate_outside_event_loop(
    self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
    """Create a Condition with no event loop running, then use it inside ``run()``."""
    condition = Condition()

    async def enter_and_exit() -> None:
        async with condition:
            pass

    # Statistics must be readable before any event loop has been started
    stats = condition.statistics()
    assert stats.tasks_waiting == 0

    run(
        enter_and_exit,
        backend=anyio_backend_name,
        backend_options=anyio_backend_options,
    )


class TestSemaphore:
async def test_contextmanager(self) -> None:
Expand Down Expand Up @@ -426,6 +474,22 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

def test_instantiate_outside_event_loop(
    self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
    """Create a Semaphore with no event loop running, then use it inside ``run()``."""
    semaphore = Semaphore(1)

    async def enter_and_exit() -> None:
        async with semaphore:
            pass

    # Statistics must be readable before any event loop has been started
    stats = semaphore.statistics()
    assert stats.tasks_waiting == 0

    run(
        enter_and_exit,
        backend=anyio_backend_name,
        backend_options=anyio_backend_options,
    )


class TestCapacityLimiter:
async def test_bad_init_type(self) -> None:
Expand Down Expand Up @@ -595,3 +659,25 @@ async def worker(entered_event: Event) -> None:

# Allow all tasks to exit
continue_event.set()

def test_instantiate_outside_event_loop(
    self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
    """Create a CapacityLimiter with no event loop running, then use it in ``run()``."""
    limiter = CapacityLimiter(1)

    async def enter_and_exit() -> None:
        async with limiter:
            pass

    # Properties and statistics must be readable before any event loop has started
    assert limiter.total_tokens == 1
    assert limiter.borrowed_tokens == 0
    stats = limiter.statistics()
    assert stats.total_tokens == 1
    assert stats.borrowed_tokens == 0
    assert stats.borrowers == ()
    assert stats.tasks_waiting == 0

    run(
        enter_and_exit,
        backend=anyio_backend_name,
        backend_options=anyio_backend_options,
    )