diff --git a/tenacity/__init__.py b/tenacity/__init__.py index bd60556..a7f10ed 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -88,6 +88,7 @@ if t.TYPE_CHECKING: import types + from . import asyncio as tasyncio from .retry import RetryBaseT from .stop import StopBaseT from .wait import WaitBaseT @@ -279,6 +280,14 @@ def statistics(self) -> t.Dict[str, t.Any]: self._local.statistics = t.cast(t.Dict[str, t.Any], {}) return self._local.statistics + @property + def iter_state(self) -> t.Dict[str, t.Any]: + try: + return self._local.iter_state # type: ignore[no-any-return] + except AttributeError: + self._local.iter_state = t.cast(t.Dict[str, t.Any], {}) + return self._local.iter_state + def wraps(self, f: WrappedFn) -> WrappedFn: """Wrap a function for retrying. @@ -303,20 +312,13 @@ def begin(self) -> None: self.statistics["attempt_number"] = 1 self.statistics["idle_for"] = 0 - def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa - fut = retry_state.outcome - if fut is None: - if self.before is not None: - self.before(retry_state) - return DoAttempt() - - is_explicit_retry = fut.failed and isinstance(fut.exception(), TryAgain) - if not (is_explicit_retry or self.retry(retry_state)): - return fut.result() + def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: + self.iter_state["actions"].append(fn) - if self.after is not None: - self.after(retry_state) + def _run_retry(self, retry_state: "RetryCallState") -> None: + self.iter_state["retry_run_result"] = self.retry(retry_state) + def _run_wait(self, retry_state: "RetryCallState") -> None: if self.wait: sleep = self.wait(retry_state) else: @@ -324,24 +326,74 @@ def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.A retry_state.upcoming_sleep = sleep + def _run_stop(self, retry_state: "RetryCallState") -> None: self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start - if self.stop(retry_state): + self.iter_state["stop_run_result"] = self.stop(retry_state) + + def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa + self._begin_iter(retry_state) + result = None + for action in self.iter_state["actions"]: + result = action(retry_state) + return result + + def _begin_iter(self, retry_state: "RetryCallState") -> None: # noqa + self.iter_state.clear() + self.iter_state["actions"] = [] + + fut = retry_state.outcome + if fut is None: + if self.before is not None: + self._add_action_func(self.before) + self._add_action_func(lambda rs: DoAttempt()) + return + + self.iter_state["is_explicit_retry"] = fut.failed and isinstance(fut.exception(), TryAgain) + if not self.iter_state["is_explicit_retry"]: + self._add_action_func(self._run_retry) + self._add_action_func(self._post_retry_check_actions) + + def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None: + if not (self.iter_state["is_explicit_retry"] or self.iter_state.get("retry_run_result")): + self._add_action_func(lambda rs: rs.outcome.result()) + return + + if self.after is not None: + self._add_action_func(self.after) + + self._add_action_func(self._run_wait) + self._add_action_func(self._run_stop) + self._add_action_func(self._post_stop_check_actions) + + def _post_stop_check_actions(self, retry_state: "RetryCallState") -> None: + if self.iter_state["stop_run_result"]: if self.retry_error_callback: - return self.retry_error_callback(retry_state) - retry_exc = self.retry_error_cls(fut) - if self.reraise: - raise retry_exc.reraise() - raise retry_exc from fut.exception() + self._add_action_func(self.retry_error_callback) + return + + def exc_check(rs: "RetryCallState") -> None: + fut = t.cast(Future, rs.outcome) + retry_exc = self.retry_error_cls(fut) + if self.reraise: + raise retry_exc.reraise() + raise retry_exc from fut.exception() + + self._add_action_func(exc_check) + return + + def next_action(rs: "RetryCallState") -> None: + sleep = rs.upcoming_sleep + rs.next_action = RetryAction(sleep) + rs.idle_for += sleep + self.statistics["idle_for"] += sleep + self.statistics["attempt_number"] += 1 - retry_state.next_action = RetryAction(sleep) - retry_state.idle_for += sleep - self.statistics["idle_for"] += sleep - self.statistics["attempt_number"] += 1 + self._add_action_func(next_action) if self.before_sleep is not None: - self.before_sleep(retry_state) + self._add_action_func(self.before_sleep) - return DoSleep(sleep) + self._add_action_func(lambda rs: DoSleep(rs.upcoming_sleep)) def __iter__(self) -> t.Generator[AttemptManager, None, None]: self.begin() @@ -505,16 +557,16 @@ def retry(func: WrappedFn) -> WrappedFn: @t.overload def retry( - sleep: t.Callable[[t.Union[int, float]], t.Optional[t.Awaitable[None]]] = sleep, - stop: "StopBaseT" = stop_never, - wait: "WaitBaseT" = wait_none(), - retry: "RetryBaseT" = retry_if_exception_type(), - before: t.Callable[["RetryCallState"], None] = before_nothing, - after: t.Callable[["RetryCallState"], None] = after_nothing, - before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None, + sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep, + stop: "t.Union[StopBaseT, tasyncio.stop.StopBaseT]" = stop_never, + wait: "t.Union[WaitBaseT, tasyncio.wait.WaitBaseT]" = wait_none(), + retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(), + before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, + after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, + before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, reraise: bool = False, retry_error_cls: t.Type["RetryError"] = RetryError, - retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None, + retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, ) -> t.Callable[[WrappedFn], WrappedFn]: ... @@ -549,7 +601,7 @@ def wrap(f: WrappedFn) -> WrappedFn: return wrap -from tenacity._asyncio import AsyncRetrying # noqa:E402,I100 +from tenacity.asyncio import AsyncRetrying # noqa:E402,I100 if tornado: from tenacity.tornadoweb import TornadoRetrying diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py deleted file mode 100644 index d901cbd..0000000 --- a/tenacity/_asyncio.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2016 Étienne Bersac -# Copyright 2016 Julien Danjou -# Copyright 2016 Joshua Harlow -# Copyright 2013-2014 Ray Holder -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import sys -import typing as t -from asyncio import sleep - -from tenacity import AttemptManager -from tenacity import BaseRetrying -from tenacity import DoAttempt -from tenacity import DoSleep -from tenacity import RetryCallState - -WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") -WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) - - -class AsyncRetrying(BaseRetrying): - sleep: t.Callable[[float], t.Awaitable[t.Any]] - - def __init__(self, sleep: t.Callable[[float], t.Awaitable[t.Any]] = sleep, **kwargs: t.Any) -> None: - super().__init__(**kwargs) - self.sleep = sleep - - async def __call__( # type: ignore[override] - self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any - ) -> WrappedFnReturnT: - self.begin() - - retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) - while True: - do = self.iter(retry_state=retry_state) - if isinstance(do, DoAttempt): - try: - result = await fn(*args, **kwargs) - except BaseException: # noqa: B902 - retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type] - else: - retry_state.set_result(result) - elif isinstance(do, DoSleep): - retry_state.prepare_for_next_attempt() - await self.sleep(do) - else: - return do # type: ignore[no-any-return] - - def __iter__(self) -> t.Generator[AttemptManager, None, None]: - raise TypeError("AsyncRetrying object is not iterable") - - def __aiter__(self) -> "AsyncRetrying": - self.begin() - self._retry_state = RetryCallState(self, fn=None, args=(), kwargs={}) - return self - - async def __anext__(self) -> AttemptManager: - while True: - do = self.iter(retry_state=self._retry_state) - if do is None: - raise StopAsyncIteration - elif isinstance(do, DoAttempt): - return AttemptManager(retry_state=self._retry_state) - elif isinstance(do, DoSleep): - self._retry_state.prepare_for_next_attempt() - await self.sleep(do) - else: - raise StopAsyncIteration - - def wraps(self, fn: WrappedFn) -> WrappedFn: - fn = super().wraps(fn) - # Ensure wrapper is recognized as a coroutine function. - - @functools.wraps(fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")) - async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: - return await fn(*args, **kwargs) - - # Preserve attributes - async_wrapped.retry = fn.retry # type: ignore[attr-defined] - async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined] - - return async_wrapped # type: ignore[return-value] diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py new file mode 100644 index 0000000..12ae126 --- /dev/null +++ b/tenacity/asyncio/__init__.py @@ -0,0 +1,234 @@ +# Copyright 2016 Étienne Bersac +# Copyright 2016 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import functools +import inspect +import sys +import typing as t + +from tenacity import AttemptManager +from tenacity import BaseRetrying +from tenacity import DoAttempt +from tenacity import DoSleep +from tenacity import RetryCallState +from tenacity import RetryError +from tenacity import after_nothing +from tenacity import before_nothing + +# Import all built-in retry strategies for easier usage. +from .retry import RetryBaseT +from .retry import retry_all # noqa +from .retry import retry_always # noqa +from .retry import retry_any # noqa +from .retry import retry_if_exception # noqa +from .retry import retry_if_exception_type # noqa +from .retry import retry_if_exception_cause_type # noqa +from .retry import retry_if_not_exception_type # noqa +from .retry import retry_if_not_result # noqa +from .retry import retry_if_result # noqa +from .retry import retry_never # noqa +from .retry import retry_unless_exception_type # noqa +from .retry import retry_if_exception_message # noqa +from .retry import retry_if_not_exception_message # noqa +# Import all built-in stop strategies for easier usage. +from .stop import StopBaseT +from .stop import stop_after_attempt # noqa +from .stop import stop_after_delay # noqa +from .stop import stop_before_delay # noqa +from .stop import stop_all # noqa +from .stop import stop_any # noqa +from .stop import stop_never # noqa +from .stop import stop_when_event_set # noqa +# Import all built-in wait strategies for easier usage. +from .wait import WaitBaseT +from .wait import wait_chain # noqa +from .wait import wait_combine # noqa +from .wait import wait_exponential # noqa +from .wait import wait_fixed # noqa +from .wait import wait_incrementing # noqa +from .wait import wait_none # noqa +from .wait import wait_random # noqa +from .wait import wait_random_exponential # noqa +from .wait import wait_random_exponential as wait_full_jitter # noqa +from .wait import wait_exponential_jitter # noqa + +WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") +WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) + + +def _is_coroutine_callable(call: t.Callable[..., t.Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(dunder_call) + + +class AsyncRetrying(BaseRetrying): + def __init__( + self, + sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = asyncio.sleep, + stop: "t.Union[StopBaseT, StopBaseT]" = stop_never, + wait: "t.Union[WaitBaseT, WaitBaseT]" = wait_none(), + retry: "t.Union[RetryBaseT, RetryBaseT]" = retry_if_exception_type(), + before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, + after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, + before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, + reraise: bool = False, + retry_error_cls: t.Type["RetryError"] = RetryError, + retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, + ) -> None: + super().__init__( + sleep=sleep, # type: ignore[arg-type] + stop=stop, # type: ignore[arg-type] + wait=wait, # type: ignore[arg-type] + retry=retry, # type: ignore[arg-type] + before=before, # type: ignore[arg-type] + after=after, # type: ignore[arg-type] + before_sleep=before_sleep, # type: ignore[arg-type] + reraise=reraise, + retry_error_cls=retry_error_cls, + retry_error_callback=retry_error_callback, + ) + + async def __call__( # type: ignore[override] + self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any + ) -> WrappedFnReturnT: + self.begin() + + retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) + while True: + do = await self.iter(retry_state=retry_state) + if isinstance(do, DoAttempt): + try: + result = await fn(*args, **kwargs) + except BaseException: # noqa: B902 + retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type] + else: + retry_state.set_result(result) + elif isinstance(do, DoSleep): + retry_state.prepare_for_next_attempt() + await self.sleep(do) # type: ignore[misc] + else: + return do # type: ignore[no-any-return] + + @classmethod + def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: + if _is_coroutine_callable(fn): + return fn + + async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: + return fn(*args, **kwargs) + + return inner + + def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: + self.iter_state["actions"].append(self._wrap_action_func(fn)) + + async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.iter_state["retry_run_result"] = await self._wrap_action_func(self.retry)(retry_state) + + async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + if self.wait: + sleep = await self._wrap_action_func(self.wait)(retry_state) + else: + sleep = 0.0 + + retry_state.upcoming_sleep = sleep + + async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] + self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start + self.iter_state["stop_run_result"] = await self._wrap_action_func(self.stop)(retry_state) + + async def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003 + self._begin_iter(retry_state) + result = None + for action in self.iter_state["actions"]: + result = await action(retry_state) + return result + + def __iter__(self) -> t.Generator[AttemptManager, None, None]: + raise TypeError("AsyncRetrying object is not iterable") + + def __aiter__(self) -> "AsyncRetrying": + self.begin() + self._retry_state = RetryCallState(self, fn=None, args=(), kwargs={}) + return self + + async def __anext__(self) -> AttemptManager: + while True: + do = await self.iter(retry_state=self._retry_state) + if do is None: + raise StopAsyncIteration + elif isinstance(do, DoAttempt): + return AttemptManager(retry_state=self._retry_state) + elif isinstance(do, DoSleep): + self._retry_state.prepare_for_next_attempt() + await self.sleep(do) # type: ignore[misc] + else: + raise StopAsyncIteration + + def wraps(self, fn: WrappedFn) -> WrappedFn: + fn = super().wraps(fn) + # Ensure wrapper is recognized as a coroutine function. + + @functools.wraps(fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")) + async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: + return await fn(*args, **kwargs) + + # Preserve attributes + async_wrapped.retry = fn.retry # type: ignore[attr-defined] + async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined] + + return async_wrapped # type: ignore[return-value] + + +__all__ = [ + "retry_all", + "retry_always", + "retry_any", + "retry_if_exception", + "retry_if_exception_type", + "retry_if_exception_cause_type", + "retry_if_not_exception_type", + "retry_if_not_result", + "retry_if_result", + "retry_never", + "retry_unless_exception_type", + "retry_if_exception_message", + "retry_if_not_exception_message", + "stop_after_attempt", + "stop_after_delay", + "stop_before_delay", + "stop_all", + "stop_any", + "stop_never", + "stop_when_event_set", + "wait_chain", + "wait_combine", + "wait_exponential", + "wait_fixed", + "wait_incrementing", + "wait_none", + "wait_random", + "wait_random_exponential", + "wait_full_jitter", + "wait_exponential_jitter", + "WrappedFn", + "AsyncRetrying", +] diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py new file mode 100644 index 0000000..eb63286 --- /dev/null +++ b/tenacity/asyncio/retry.py @@ -0,0 +1,283 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import re +import typing + +from tenacity import retry_base + +if typing.TYPE_CHECKING: + from tenacity import RetryCallState + + +class retry_base(retry_base): # type: ignore[no-redef] + """Abstract base class for retry strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + pass + + +RetryBaseT = typing.Union[retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] + + +class _retry_never(retry_base): + """Retry strategy that never rejects any result.""" + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return False + + +retry_never = _retry_never() + + +class _retry_always(retry_base): + """Retry strategy that always rejects any result.""" + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return True + + +retry_always = _retry_always() + + +class retry_if_exception(retry_base): + """Retry strategy that retries if an exception verifies a predicate.""" + + def __init__(self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if retry_state.outcome.failed: + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + else: + return False + + +class retry_if_exception_type(retry_if_exception): + """Retries if an exception has been raised of one or more types.""" + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_types = exception_types + + async def predicate(e: BaseException) -> bool: + return isinstance(e, exception_types) + + super().__init__(predicate) + + +class retry_if_not_exception_type(retry_if_exception): + """Retries except an exception has been raised of one or more types.""" + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_types = exception_types + + async def predicate(e: BaseException) -> bool: + return not isinstance(e, exception_types) + + super().__init__(predicate) + + +class retry_unless_exception_type(retry_if_exception): + """Retries until an exception is raised of one or more types.""" + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_types = exception_types + + async def predicate(e: BaseException) -> bool: + return not isinstance(e, exception_types) + + super().__init__(predicate) + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + # always retry if no exception was raised + if not retry_state.outcome.failed: + return True + + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + + +class retry_if_exception_cause_type(retry_base): + """Retries if any of the causes of the raised exception is of one or more types. + + The check on the type of the cause of the exception is done recursively (until finding + an exception in the chain that has no `__cause__`) + """ + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_cause_types = exception_types + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__ called before outcome was set") + + if retry_state.outcome.failed: + exc = retry_state.outcome.exception() + while exc is not None: + if isinstance(exc.__cause__, self.exception_cause_types): + return True + exc = exc.__cause__ + + return False + + +class retry_if_result(retry_base): + """Retries if the result verifies a predicate.""" + + def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return await self.predicate(retry_state.outcome.result()) + else: + return False + + +class retry_if_not_result(retry_base): + """Retries if the result refutes a predicate.""" + + def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return not await self.predicate(retry_state.outcome.result()) + else: + return False + + +class retry_if_exception_message(retry_if_exception): + """Retries if an exception message equals or matches.""" + + def __init__( + self, + message: typing.Optional[str] = None, + match: typing.Optional[str] = None, + ) -> None: + if message and match: + raise TypeError(f"{self.__class__.__name__}() takes either 'message' or 'match', not both") + + # set predicate + if message: + + async def message_fnc(exception: BaseException) -> bool: + return message == str(exception) + + predicate = message_fnc + elif match: + prog = re.compile(match) + + async def match_fnc(exception: BaseException) -> bool: + return bool(prog.match(str(exception))) + + predicate = match_fnc + else: + raise TypeError(f"{self.__class__.__name__}() missing 1 required argument 'message' or 'match'") + + super().__init__(predicate) + + +class retry_if_not_exception_message(retry_if_exception_message): + """Retries until an exception message equals or matches.""" + + def __init__( + self, + message: typing.Optional[str] = None, + match: typing.Optional[str] = None, + ) -> None: + super().__init__(message, match) + if_predicate = self.predicate + + # invert predicate + async def predicate(e: BaseException) -> bool: + return not if_predicate(e) + + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return True + + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + + +class retry_any(retry_base): + """Retries if any of the retries condition is valid.""" + + def __init__(self, *retries: retry_base) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return any(r(retry_state) for r in self.retries) + + +class retry_all(retry_base): + """Retries if all the retries condition are valid.""" + + def __init__(self, *retries: retry_base) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return all(r(retry_state) for r in self.retries) diff --git a/tenacity/asyncio/stop.py b/tenacity/asyncio/stop.py new file mode 100644 index 0000000..1528426 --- /dev/null +++ b/tenacity/asyncio/stop.py @@ -0,0 +1,122 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import typing + +from tenacity import _utils +from tenacity.stop import stop_base + +if typing.TYPE_CHECKING: + import asyncio + + from tenacity import RetryCallState + + +class stop_base(stop_base): # type: ignore[no-redef] + """Abstract base class for stop strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + pass + + +StopBaseT = typing.Union[stop_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] + + +class stop_any(stop_base): + """Stop if any of the stop condition is valid.""" + + def __init__(self, *stops: stop_base) -> None: + self.stops = stops + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return any(x(retry_state) for x in self.stops) + + +class stop_all(stop_base): + """Stop if all the stop conditions are valid.""" + + def __init__(self, *stops: stop_base) -> None: + self.stops = stops + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return all(x(retry_state) for x in self.stops) + + +class _stop_never(stop_base): + """Never stop.""" + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return False + + +stop_never = _stop_never() + + +class stop_when_event_set(stop_base): + """Stop when the given event is set.""" + + def __init__(self, event: "asyncio.Event") -> None: + self.event = event + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return self.event.is_set() + + +class stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return retry_state.attempt_number >= self.max_attempt_number + + +class stop_after_delay(stop_base): + """ + Stop when the time from the first attempt >= limit. + + Note: `max_delay` will be exceeded, so when used with a `wait`, the actual total delay will be greater + than `max_delay` by some of the final sleep period before `max_delay` is exceeded. + + If you need stricter timing with waits, consider `stop_before_delay` instead. + """ + + def __init__(self, max_delay: _utils.time_unit_type) -> None: + self.max_delay = _utils.to_seconds(max_delay) + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.seconds_since_start is None: + raise RuntimeError("__call__() called but seconds_since_start is not set") + return retry_state.seconds_since_start >= self.max_delay + + +class stop_before_delay(stop_base): + """ + Stop right before the next attempt would take place after the time from the first attempt >= limit. + + Most useful when you are using with a `wait` function like wait_random_exponential, but need to make + sure that the max_delay is not exceeded. + """ + + def __init__(self, max_delay: _utils.time_unit_type) -> None: + self.max_delay = _utils.to_seconds(max_delay) + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.seconds_since_start is None: + raise RuntimeError("__call__() called but seconds_since_start is not set") + return retry_state.seconds_since_start + retry_state.upcoming_sleep >= self.max_delay diff --git a/tenacity/asyncio/wait.py b/tenacity/asyncio/wait.py new file mode 100644 index 0000000..021b34d --- /dev/null +++ b/tenacity/asyncio/wait.py @@ -0,0 +1,219 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import random +import typing + +from tenacity import _utils +from tenacity.wait import wait_base + +if typing.TYPE_CHECKING: + from tenacity import RetryCallState + + +class wait_base(wait_base): # type: ignore[no-redef] + """Abstract base class for wait strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + pass + + +WaitBaseT = typing.Union[wait_base, typing.Callable[["RetryCallState"], typing.Awaitable[typing.Union[float, int]]]] + + +class wait_fixed(wait_base): + """Wait strategy that waits a fixed amount of time between each retry.""" + + def __init__(self, wait: _utils.time_unit_type) -> None: + self.wait_fixed = _utils.to_seconds(wait) + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + return self.wait_fixed + + +class wait_none(wait_fixed): + """Wait strategy that doesn't wait at all before retrying.""" + + def __init__(self) -> None: + super().__init__(0) + + +class wait_random(wait_base): + """Wait strategy that waits a random amount of time between min/max.""" + + def __init__(self, min: _utils.time_unit_type = 0, max: _utils.time_unit_type = 1) -> None: # noqa + self.wait_random_min = _utils.to_seconds(min) + self.wait_random_max = _utils.to_seconds(max) + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + return self.wait_random_min + (random.random() * (self.wait_random_max - self.wait_random_min)) + + +class wait_combine(wait_base): + """Combine several waiting strategies.""" + + def __init__(self, *strategies: wait_base) -> None: + self.wait_funcs = strategies + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + return sum(x(retry_state=retry_state) for x in self.wait_funcs) + + +class wait_chain(wait_base): + """Chain two or more waiting strategies. + + If all strategies are exhausted, the very last strategy is used + thereafter. + + For example:: + + @retry(wait=wait_chain(*[wait_fixed(1) for i in range(3)] + + [wait_fixed(2) for j in range(5)] + + [wait_fixed(5) for k in range(4))) + def wait_chained(): + print("Wait 1s for 3 attempts, 2s for 5 attempts and 5s + thereafter.") + """ + + def __init__(self, *strategies: wait_base) -> None: + self.strategies = strategies + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + wait_func_no = min(max(retry_state.attempt_number, 1), len(self.strategies)) + wait_func = self.strategies[wait_func_no - 1] + return wait_func(retry_state=retry_state) + + +class wait_incrementing(wait_base): + """Wait an incremental amount of time after each attempt. + + Starting at a starting value and incrementing by a value for each attempt + (and restricting the upper limit to some maximum value). + """ + + def __init__( + self, + start: _utils.time_unit_type = 0, + increment: _utils.time_unit_type = 100, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + ) -> None: + self.start = _utils.to_seconds(start) + self.increment = _utils.to_seconds(increment) + self.max = _utils.to_seconds(max) + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + result = self.start + (self.increment * (retry_state.attempt_number - 1)) + return max(0, min(result, self.max)) + + +class wait_exponential(wait_base): + """Wait strategy that applies exponential backoff. + + It allows for a customized multiplier and an ability to restrict the + upper and lower limits to some maximum and minimum value. + + The intervals are fixed (i.e. there is no jitter), so this strategy is + suitable for balancing retries against latency when a required resource is + unavailable for an unknown duration, but *not* suitable for resolving + contention between multiple processes for a shared resource. Use + wait_random_exponential for the latter case. + """ + + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) + + +class wait_random_exponential(wait_exponential): + """Random wait with exponentially widening window. + + An exponential backoff strategy used to mediate contention between multiple + uncoordinated processes for a shared resource in distributed systems. This + is the sense in which "exponential backoff" is meant in e.g. Ethernet + networking, and corresponds to the "Full Jitter" algorithm described in + this blog post: + + https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + + Each retry occurs at a random time in a geometrically expanding interval. + It allows for a custom multiplier and an ability to restrict the upper + limit of the random interval to some maximum value. + + Example:: + + wait_random_exponential(multiplier=0.5, # initial window 0.5s + max=60) # max 60s timeout + + When waiting for an unavailable resource to become available again, as + opposed to trying to resolve contention for a shared resource, the + wait_exponential strategy (which uses a fixed interval) may be preferable. + + """ + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + high = await super().__call__(retry_state=retry_state) + return random.uniform(0, high) + + +class wait_exponential_jitter(wait_base): + """Wait strategy that applies exponential backoff and jitter. + + It allows for a customized initial wait, maximum wait and jitter. + + This implements the strategy described here: + https://cloud.google.com/storage/docs/retry-strategy + + The wait time is min(initial * 2**n + random.uniform(0, jitter), maximum) + where n is the retry count. + """ + + def __init__( + self, + initial: float = 1, + max: float = _utils.MAX_WAIT, # noqa + exp_base: float = 2, + jitter: float = 1, + ) -> None: + self.initial = initial + self.max = max + self.exp_base = exp_base + self.jitter = jitter + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + jitter = random.uniform(0, self.jitter) + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.initial * exp + jitter + except OverflowError: + result = self.max + return max(0, min(result, self.max)) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 542f540..7d832a3 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -22,7 +22,7 @@ import tenacity from tenacity import AsyncRetrying, RetryError -from tenacity import _asyncio as tasyncio +from tenacity import asyncio as tasyncio from tenacity import retry, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed @@ -55,6 +55,12 @@ async def _retryable_coroutine_with_2_attempts(thing): thing.go() +@retry(stop=tasyncio.stop_after_attempt(2)) +async def _async_retryable_coroutine_with_2_attempts(thing): + await asyncio.sleep(0.00001) + thing.go() + + class TestAsync(unittest.TestCase): @asynctest async def test_retry(self): @@ -82,6 +88,14 @@ async def test_stop_after_attempt(self): except RetryError: assert thing.counter == 2 + @asynctest + async def test_stop_after_attempt_async(self): + thing = NoIOErrorAfterCount(2) + try: + await _async_retryable_coroutine_with_2_attempts(thing) + except RetryError: + assert thing.counter == 2 + def test_repr(self): repr(tasyncio.AsyncRetrying())