diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8598c1f..8a79e29 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,9 @@ +3.1.1 +----- + +* Since 3.8, CancelledError is a subclass of BaseException rather than Exception, so we need to catch it explicitly. +* Enabled `mypy` for `wrapper` function. + 3.1.0 ----- diff --git a/memoize/statuses.py b/memoize/statuses.py index 06c0d68..9ae62b0 100644 --- a/memoize/statuses.py +++ b/memoize/statuses.py @@ -5,7 +5,7 @@ import datetime import logging from abc import ABCMeta, abstractmethod -from asyncio import Future +from asyncio import Future, CancelledError from typing import Dict, Awaitable, Union from memoize.entry import CacheKey, CacheEntry @@ -30,7 +30,7 @@ def mark_updated(self, key: CacheKey, entry: CacheEntry) -> None: raise NotImplementedError() @abstractmethod - def mark_update_aborted(self, key: CacheKey, exception: Exception) -> None: + def mark_update_aborted(self, key: CacheKey, exception: Union[Exception, CancelledError]) -> None: """Informs that update failed to complete. Calls to 'is_being_updated' will return False until 'mark_being_updated' will be called. Accepts exception to propagate it across all clients awaiting an update.""" @@ -79,7 +79,7 @@ def mark_updated(self, key: CacheKey, entry: CacheEntry) -> None: update = self._updates_in_progress.pop(key) update.set_result(entry) - def mark_update_aborted(self, key: CacheKey, exception: Exception) -> None: + def mark_update_aborted(self, key: CacheKey, exception: Union[Exception, CancelledError]) -> None: if key not in self._updates_in_progress: raise ValueError('Key {} is not being updated'.format(key)) update = self._updates_in_progress.pop(key) diff --git a/memoize/wrapper.py b/memoize/wrapper.py index 9875072..d257917 100644 --- a/memoize/wrapper.py +++ b/memoize/wrapper.py @@ -6,7 +6,7 @@ import datetime import functools import logging -from asyncio import Future +from asyncio import Future, CancelledError from typing import Optional, Callable from memoize.configuration import CacheConfiguration, NotConfiguredCacheCalledException, \ @@ -17,8 +17,8 @@ from memoize.statuses import UpdateStatuses, InMemoryLocks -def memoize(method: Optional[Callable] = None, configuration: CacheConfiguration = None, - invalidation: InvalidationSupport = None, update_statuses: UpdateStatuses = None): +def memoize(method: Optional[Callable] = None, configuration: Optional[CacheConfiguration] = None, + invalidation: Optional[InvalidationSupport] = None, update_statuses: Optional[UpdateStatuses] = None): """Wraps function with memoization. If entry reaches time it should be updated, refresh is performed in background, @@ -116,14 +116,14 @@ async def refresh(actual_entry: Optional[CacheEntry], key: CacheKey, logger.debug('Timeout for %s: %s', key, e) update_statuses.mark_update_aborted(key, e) raise CachedMethodFailedException('Refresh timed out') from e - except Exception as e: + except (Exception, CancelledError) as e: logger.debug('Error while refreshing cache for %s: %s', key, e) update_statuses.mark_update_aborted(key, e) raise CachedMethodFailedException('Refresh failed to complete') from e @functools.wraps(method) async def wrapper(*args, **kwargs): - if not configuration.configured(): + if configuration is None or not configuration.configured(): raise NotConfiguredCacheCalledException() configuration_snapshot = MutableCacheConfiguration.initialized_with(configuration) diff --git a/mypy.ini b/mypy.ini index d3d503f..9384ef3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,2 +1,3 @@ [mypy] no_implicit_optional=False +check_untyped_defs=True \ No newline at end of file diff --git a/setup.py b/setup.py index 74e4a2f..22ab504 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ def prepare_description(): setup( name='py-memoize', - version='3.1.0', + version='3.1.1', author='Michal Zmuda', author_email='zmu.michal@gmail.com', url='https://github.com/DreamLab/memoize', diff --git a/tests/end2end/test_wrapper.py b/tests/end2end/test_wrapper.py index 2441ed9..6a40f7a 100644 --- a/tests/end2end/test_wrapper.py +++ b/tests/end2end/test_wrapper.py @@ -1,5 +1,6 @@ import asyncio import time +from asyncio import CancelledError from datetime import timedelta from unittest.mock import Mock @@ -174,6 +175,37 @@ async def get_value(arg, kwarg=None): assert context.value.__class__ == CachedMethodFailedException assert str(context.value.__cause__) == str(ValueError('stub0')) + async def test_should_return_cancelled_exception_for_all_concurrent_callers(self): + # given + value = 0 + + @memoize() + async def get_value(arg, kwarg=None): + new_task = asyncio.create_task(asyncio.sleep(1)) + new_task.cancel() # this will raise CancelledError + await new_task + + # when + res1 = get_value('test', kwarg='args1') + res2 = get_value('test', kwarg='args1') + res3 = get_value('test', kwarg='args1') + + # then + with pytest.raises(Exception) as context: + await res1 + assert context.value.__class__ == CachedMethodFailedException + assert str(context.value.__cause__) == str(CancelledError()) + + with pytest.raises(Exception) as context: + await res2 + assert context.value.__class__ == CachedMethodFailedException + assert str(context.value.__cause__) == str(CancelledError()) + + with pytest.raises(Exception) as context: + await res3 + assert context.value.__class__ == CachedMethodFailedException + assert str(context.value.__cause__) == str(CancelledError()) + async def test_should_return_timeout_for_all_concurrent_callers(self): # given value = 0