diff --git a/asyncstdlib/functools.py b/asyncstdlib/functools.py
index d26b21a..3362ae8 100644
--- a/asyncstdlib/functools.py
+++ b/asyncstdlib/functools.py
@@ -1,3 +1,4 @@
+from asyncio import iscoroutinefunction
 from typing import (
     Callable,
     Awaitable,
@@ -7,13 +8,15 @@
     Generator,
     Optional,
     Coroutine,
-    overload,
+    AsyncContextManager,
+    Type,
+    cast,
 )
 
-from ._typing import T, AC, AnyIterable
+from ._typing import T, AC, AnyIterable, R
 from ._core import ScopedIter, awaitify as _awaitify, Sentinel
 from .builtins import anext
-from ._utility import public_module
+from .contextlib import nullcontext
 
 from ._lrucache import (
     lru_cache,
@@ -32,6 +35,7 @@
     "LRUAsyncBoundCallable",
     "reduce",
     "cached_property",
+    "CachedProperty",
 ]
 
 
@@ -45,16 +49,16 @@ def cache(user_function: AC) -> LRUAsyncCallable[AC]:
     return lru_cache(maxsize=None)(user_function)
 
 
-class AwaitableValue(Generic[T]):
+class AwaitableValue(Generic[R]):
     """Helper to provide an arbitrary value in ``await``"""
 
     __slots__ = ("value",)
 
-    def __init__(self, value: T):
+    def __init__(self, value: R):
         self.value = value
 
     # noinspection PyUnreachableCode
-    def __await__(self) -> Generator[None, None, T]:
+    def __await__(self) -> Generator[None, None, R]:
         return self.value
         yield  # type: ignore # pragma: no cover
 
@@ -62,27 +66,136 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.value!r})"
 
 
-class _RepeatableCoroutine(Generic[T]):
-    """Helper to ``await`` a coroutine also more or less than just once"""
+class _FutureCachedValue(Generic[R, T]):
+    """A placeholder object to control concurrent access to a cached awaitable value.
 
-    __slots__ = ("call", "args", "kwargs")
+    When given a lock to coordinate access, only the first task to await on a
+    cached property triggers the underlying coroutine. Once a value has been
+    produced, all tasks are unblocked and given the same, single value.
+
+    """
+
+    __slots__ = ("_get_attribute", "_instance", "_name", "_lock")
 
     def __init__(
-        self, __call: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any
+        self,
+        get_attribute: Callable[[T], Coroutine[Any, Any, R]],
+        instance: T,
+        name: str,
+        lock: AsyncContextManager[Any],
     ):
-        self.call = __call
-        self.args = args
-        self.kwargs = kwargs
+        self._get_attribute = get_attribute
+        self._instance = instance
+        self._name = name
+        self._lock = lock
+
+    def __await__(self) -> Generator[None, None, R]:
+        return self._await_impl().__await__()
+
+    @property
+    def _instance_value(self) -> Awaitable[R]:
+        """Retrieve whatever is currently cached on the instance
+
+        If the instance (no longer) has this attribute, it was deleted and the
+        process is restarted by delegating to the descriptor.
 
-    def __await__(self) -> Generator[Any, Any, T]:
-        return self.call(*self.args, **self.kwargs).__await__()
+        """
+        try:
+            return self._instance.__dict__[self._name]
+        except KeyError:
+            # something deleted the cached value or future cached value placeholder.
+            # Restart the fetch by delegating to the cached_property descriptor.
+            return getattr(self._instance, self._name)
+
+    async def _await_impl(self) -> R:
+        if (stored := self._instance_value) is self:
+            # attempt to get the lock
+            async with self._lock:
+                # check again for a cached value
+                if (stored := self._instance_value) is self:
+                    # the instance attribute is still this placeholder, and we
+                    # hold the lock. Start the getter to store the value on the
+                    # instance and return the value.
+                    return await self._get_attribute(self._instance)
+
+        # another task produced a value, or the instance.__dict__ object was
+        # deleted in the interim.
+        return await stored
 
     def __repr__(self) -> str:
-        return f"<{self.__class__.__name__} object {self.call.__name__} at {id(self)}>"
+        return (
+            f"<{type(self).__name__} for '{type(self._instance).__name__}."
+            f"{self._name}' at {id(self):#x}>"
+        )
+
+
+class CachedProperty(Generic[T, R]):
+    def __init__(
+        self,
+        getter: Callable[[T], Awaitable[R]],
+        asynccontextmanager_type: Type[AsyncContextManager[Any]] = nullcontext,
+    ):
+        self.func = getter
+        self.attrname = None
+        self.__doc__ = getter.__doc__
+        self._asynccontextmanager_type = asynccontextmanager_type
+
+    def __set_name__(self, owner: Any, name: str) -> None:
+        if self.attrname is None:
+            self.attrname = name
+        elif name != self.attrname:
+            raise TypeError(
+                "Cannot assign the same cached_property to two different names "
+                f"({self.attrname!r} and {name!r})."
+            )
+
+    def __get__(
+        self, instance: Optional[T], owner: Optional[Type[Any]]
+    ) -> Union["CachedProperty[T, R]", Awaitable[R]]:
+        if instance is None:
+            return self
+
+        name = self.attrname
+        if name is None:
+            raise TypeError(
+                "Cannot use cached_property instance without calling __set_name__ on it."
+            )
+
+        # check for write access first; not all objects have __dict__ (e.g. class defines slots)
+        try:
+            cache = instance.__dict__
+        except AttributeError:
+            msg = (
+                f"No '__dict__' attribute on {type(instance).__name__!r} "
+                f"instance to cache {name!r} property."
+            )
+            raise TypeError(msg) from None
+
+        # store a placeholder for other tasks to access the future cached value
+        # on this instance. It takes care of coordinating between different
+        # tasks awaiting on the placeholder until the cached value has been
+        # produced.
+        wrapper = _FutureCachedValue(
+            self._get_attribute, instance, name, self._asynccontextmanager_type()
+        )
+        cache[name] = wrapper
+        return wrapper
+
+    async def _get_attribute(self, instance: T) -> R:
+        value = await self.func(instance)
+        name = self.attrname
+        assert name is not None  # enforced in __get__
+        instance.__dict__[name] = AwaitableValue(value)
+        return value
 
 
-@public_module(__name__, "cached_property")
-class CachedProperty(Generic[T]):
+def cached_property(
+    type_or_getter: Union[Type[AsyncContextManager[Any]], Callable[[T], Awaitable[R]]],
+    /,
+) -> Union[
+    Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]],
+    CachedProperty[T, R],
+]:
     """
     Transform a method into an attribute whose value is cached
 
@@ -108,7 +221,7 @@ def __init__(self, url):
             async def data(self):
                 return await asynclib.get(self.url)
 
-        resource = Resource(1, 3)
+        resource = Resource("http://example.com")
         print(await resource.data)  # needs some time...
         print(await resource.data)  # finishes instantly
         del resource.data
@@ -117,51 +230,53 @@ async def data(self):
     Unlike a :py:class:`property`, this type does not support
     :py:meth:`~property.setter` or :py:meth:`~property.deleter`.
 
+    If the attribute is accessed by multiple tasks before a cached value has
+    been produced, the getter can be run more than once. The final cached value
+    is determined by the last getter coroutine to return. To enforce that the
+    getter is executed at most once, provide a ``lock`` type - e.g. the
+    :py:class:`asyncio.Lock` class in an :py:mod:`asyncio` application - and
+    access is automatically synchronised.
+
+    .. code-block:: python3
+
+        from asyncio import Lock, gather
+
+        class Resource:
+            def __init__(self, url):
+                self.url = url
+
+            @a.cached_property(Lock)
+            async def data(self):
+                return await asynclib.get(self.url)
+
+        resource = Resource("http://example.com")
+        print(*(await gather(resource.data, resource.data)))
+
     .. note::
 
         Instances on which a value is to be cached must have a
        ``__dict__`` attribute that is a mutable mapping.
     """
+    if isinstance(type_or_getter, type) and issubclass(
+        type_or_getter, AsyncContextManager
+    ):
 
-    def __init__(self, getter: Callable[[Any], Awaitable[T]]):
-        self.__wrapped__ = getter
-        self._name = getter.__name__
-        self.__doc__ = getter.__doc__
-
-    def __set_name__(self, owner: Any, name: str) -> None:
-        # Check whether we can store anything on the instance
-        # Note that this is a failsafe, and might fail ugly.
-        # People who are clever enough to avoid this heuristic
-        # should also be clever enough to know the why and what.
-        if not any("__dict__" in dir(cls) for cls in owner.__mro__):
-            raise TypeError(
-                "'cached_property' requires '__dict__' "
-                f"on {owner.__name__!r} to store {name}"
+        def decorator(
+            coroutine: Callable[[T], Awaitable[R]],
+        ) -> CachedProperty[T, R]:
+            return CachedProperty(
+                coroutine,
+                asynccontextmanager_type=cast(
+                    Type[AsyncContextManager[Any]], type_or_getter
+                ),
             )
-        self._name = name
-
-    @overload
-    def __get__(self, instance: None, owner: type) -> "CachedProperty[T]": ...
-    @overload
-    def __get__(self, instance: object, owner: Optional[type]) -> Awaitable[T]: ...
-
-    def __get__(
-        self, instance: Optional[object], owner: Optional[type]
-    ) -> Union["CachedProperty[T]", Awaitable[T]]:
-        if instance is None:
-            return self
-        # __get__ may be called multiple times before it is first awaited to completion
-        # provide a placeholder that acts just like the final value does
-        return _RepeatableCoroutine(self._get_attribute, instance)
-
-    async def _get_attribute(self, instance: object) -> T:
-        value = await self.__wrapped__(instance)
-        instance.__dict__[self._name] = AwaitableValue(value)
-        return value
+        return decorator
 
+    if not iscoroutinefunction(type_or_getter):
+        raise ValueError("cached_property can only be used with a coroutine function")
 
-cached_property = CachedProperty
+    return CachedProperty(type_or_getter)
 
 
 __REDUCE_SENTINEL = Sentinel("")
diff --git a/asyncstdlib/functools.pyi b/asyncstdlib/functools.pyi
index 72fe8ab..aff4538 100644
--- a/asyncstdlib/functools.pyi
+++ b/asyncstdlib/functools.pyi
@@ -1,6 +1,6 @@
-from typing import Any, Awaitable, Callable, Generic, overload
+from typing import Any, AsyncContextManager, Awaitable, Callable, Generic, overload
 
-from ._typing import T, T1, T2, AC, AnyIterable
+from ._typing import T, T1, T2, AC, AnyIterable, R
 
 from ._lrucache import (
     LRUAsyncCallable as LRUAsyncCallable,
@@ -10,14 +10,28 @@ from ._lrucache import (
 
 def cache(user_function: AC) -> LRUAsyncCallable[AC]: ...
 
-class cached_property(Generic[T]):
-    def __init__(self, getter: Callable[[Any], Awaitable[T]]) -> None: ...
+class CachedProperty(Generic[T, R]):
+    def __init__(
+        self,
+        getter: Callable[[T], Awaitable[R]],
+        asynccontextmanager_type: type[AsyncContextManager[Any]] = ...,
+    ) -> None: ...
     def __set_name__(self, owner: Any, name: str) -> None: ...
     @overload
-    def __get__(self, instance: None, owner: type) -> "cached_property[T]": ...
+    def __get__(self, instance: None, owner: type[Any]) -> "CachedProperty[T, R]": ...
     @overload
-    def __get__(self, instance: object, owner: type | None) -> Awaitable[T]: ...
+    def __get__(self, instance: T, owner: type | None) -> Awaitable[R]: ...
+    # __set__ is not defined at runtime, but you are allowed to replace the cached value
+    def __set__(self, instance: T, value: R) -> None: ...  # type: ignore[misc]  # pyright: ignore[reportGeneralTypeIssues]
+    # __delete__ is not defined at runtime, but you are allowed to delete the cached value
+    def __delete__(self, instance: T) -> None: ...
 
+@overload
+def cached_property(getter: Callable[[T], Awaitable[R]], /) -> CachedProperty[T, R]: ...
+@overload
+def cached_property(
+    asynccontextmanager_type: type[AsyncContextManager[Any]], /
+) -> Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]]: ...
 @overload
 async def reduce(
     function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1
diff --git a/docs/source/api/functools.rst b/docs/source/api/functools.rst
index f1d52a7..68aab0b 100644
--- a/docs/source/api/functools.rst
+++ b/docs/source/api/functools.rst
@@ -35,10 +35,18 @@ Attribute Caches
 
 This type of cache tracks ``await``\ ing an attribute.
 
-.. autofunction:: cached_property(getter: (Self) → await T)
+.. py:function:: cached_property(getter: (Self) → await T, /)
     :decorator:
 
+.. autofunction:: cached_property(asynccontextmanager_type: Type[AsyncContextManager], /)((Self) → await T)
+    :decorator:
+    :noindex:
+
     .. versionadded:: 1.1.0
 
+    .. versionadded:: 3.13.0
+
+       The ``asynccontextmanager_type`` decorator parameter.
+
 Callable Caches
 ---------------
diff --git a/unittests/test_functools.py b/unittests/test_functools.py
index 08f4dff..f0efa65 100644
--- a/unittests/test_functools.py
+++ b/unittests/test_functools.py
@@ -3,8 +3,9 @@
 import pytest
 
 import asyncstdlib as a
+from asyncstdlib.functools import CachedProperty
 
-from .utility import sync, asyncify, multi_sync, Switch, Schedule
+from .utility import Lock, Schedule, Switch, asyncify, multi_sync, sync
 
 
 @sync
@@ -24,24 +25,23 @@ async def total(self):
     assert (await pair.total) == 3
     del pair.total
     assert (await pair.total) == 4
-    assert type(Pair.total) is a.cached_property
+    assert type(Pair.total) is CachedProperty
 
 
 @sync
 async def test_cache_property_nodict():
-    # note: The exact error is version- and possibly implementation-dependent.
-    # Some Python version wrap all errors from __set_name__.
-    with pytest.raises(Exception):  # noqa: B017
+    class Foo:
+        __slots__ = ()
 
-        class Pair:
-            __slots__ = "a", "b"
+        def __init__(self):
+            pass  # pragma: no cover
 
-            def __init__(self, a, b):
-                pass  # pragma: no cover
+        @a.cached_property
+        async def bar(self):
+            pass  # pragma: no cover
 
-            @a.cached_property
-            async def total(self):
-                pass  # pragma: no cover
+    with pytest.raises(TypeError):
+        Foo().bar
 
 
 @multi_sync
@@ -66,6 +66,54 @@ async def check_increment(to):
     assert (await val.cached) == 1337  # last value fetched
 
 
+@multi_sync
+async def test_cache_property_lock_order():
+    class Value:
+        def __init__(self, value):
+            self.value = value
+
+        @a.cached_property(Lock)
+        async def cached(self):
+            value = self.value
+            await Switch()
+            return value
+
+    async def check_cached(to, expected):
+        val.value = to
+        assert (await val.cached) == expected
+
+    val = Value(0)
+    await Schedule(check_cached(5, 5), check_cached(12, 5), check_cached(1337, 5))
+    assert (await val.cached) == 5  # first value fetched
+
+
+@multi_sync
+async def test_cache_property_lock_deletion():
+    class Value:
+        def __init__(self, value):
+            self.value = value
+
+        @a.cached_property(Lock)
+        async def cached(self):
+            value = self.value
+            await Switch()
+            return value
+
+    async def check_cached(to, expected):
+        val.value = to
+        assert (await val.cached) == expected
+
+    async def delete_attribute(to):
+        val.value = to
+        awaitable = val.cached
+        del val.cached
+        assert (await awaitable) == to
+
+    val = Value(0)
+    await Schedule(check_cached(5, 5), delete_attribute(12), check_cached(1337, 12))
+    assert (await val.cached) == 12  # first value fetch after deletion
+
+
 @sync
 async def test_reduce():
     async def reduction(x, y):
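
As a usage sketch (not part of the diff), the lock-aware decorator form added above can be exercised with plain asyncio as follows. The Resource class, its fetch_count counter, and the simulated delay are illustrative assumptions, not project code; the expected behaviour follows the docstring added in this diff.

    import asyncio

    import asyncstdlib as a


    class Resource:
        def __init__(self, url):
            self.url = url
            self.fetch_count = 0  # counts how often the getter actually runs

        @a.cached_property(asyncio.Lock)
        async def data(self):
            self.fetch_count += 1
            await asyncio.sleep(0.01)  # stand-in for a slow fetch
            return f"payload from {self.url}"


    async def main():
        resource = Resource("http://example.com")
        # all tasks race for the same attribute; with the lock, the getter
        # coroutine runs only once and every task receives the same value
        results = await asyncio.gather(*(resource.data for _ in range(5)))
        assert results == [results[0]] * 5
        assert resource.fetch_count == 1


    asyncio.run(main())

Without a lock type (plain ``@a.cached_property``), the value is still cached, but concurrent first accesses may each run the getter, as the updated docstring explains.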