diff --git a/src/stamina/_core.py b/src/stamina/_core.py index 43e24ce..cd7da22 100644 --- a/src/stamina/_core.py +++ b/src/stamina/_core.py @@ -5,6 +5,7 @@ from __future__ import annotations import datetime as dt +import random import sys from dataclasses import dataclass, replace @@ -15,6 +16,7 @@ AsyncIterator, Awaitable, Callable, + ClassVar, Iterator, Tuple, Type, @@ -108,12 +110,17 @@ class Attempt: .. versionadded:: 23.2.0 """ - __slots__ = ("_t_attempt",) + __slots__ = ("_t_attempt", "_next_wait_fn") _t_attempt: _t.AttemptManager - def __init__(self, attempt: _t.AttemptManager): + def __init__( + self, + attempt: _t.AttemptManager, + next_wait_fn: Callable[[int], float] | None, + ): self._t_attempt = attempt + self._next_wait_fn = next_wait_fn def __repr__(self) -> str: return f"" @@ -131,9 +138,18 @@ def next_wait(self) -> float: The number of seconds of backoff before the *next* attempt if *this* attempt fails. + + .. warning:: + This value does **not** include a possible random jitter and is + therefore just a *lower bound* of the actual value. + .. versionadded:: 24.3.0 """ - return self._t_attempt.retry_state.upcoming_sleep # type: ignore[no-any-return] + return ( + self._next_wait_fn(self._t_attempt.retry_state.attempt_number + 1) + if self._next_wait_fn + else 0.0 + ) def __enter__(self) -> None: return self._t_attempt.__enter__() # type: ignore[no-any-return] @@ -382,13 +398,30 @@ def __aiter__(self) -> _t.AsyncRetrying: @dataclass class _RetryContextIterator: - __slots__ = ("_t_kw", "_t_a_retrying", "_name", "_args", "_kw") + __slots__ = ( + "_t_kw", + "_t_a_retrying", + "_name", + "_args", + "_kw", + "_wait_jitter", + "_wait_initial", + "_wait_max", + "_wait_exp_base", + ) _t_kw: dict[str, object] _t_a_retrying: _t.AsyncRetrying _name: str _args: tuple[object, ...] _kw: dict[str, object] + _wait_jitter: float + _wait_initial: float + _wait_max: float + _wait_exp_base: float + + _random: ClassVar[random.Random] = random.Random() # noqa: S311 + @classmethod def from_params( cls, @@ -411,30 +444,26 @@ def from_params( _retry = _t.retry_if_exception_type(on) else: _retry = _t.retry_if_exception(on) - return cls( + + if isinstance(wait_initial, dt.timedelta): + wait_initial = wait_initial.total_seconds() + + if isinstance(wait_max, dt.timedelta): + wait_max = wait_max.total_seconds() + + if isinstance(wait_jitter, dt.timedelta): + wait_jitter = wait_jitter.total_seconds() + + inst = cls( _name=name, _args=args, _kw=kw, + _wait_jitter=wait_jitter, + _wait_initial=wait_initial, + _wait_max=wait_max, + _wait_exp_base=wait_exp_base, _t_kw={ "retry": _retry, - "wait": _t.wait_exponential_jitter( - initial=( - wait_initial.total_seconds() - if isinstance(wait_initial, dt.timedelta) - else wait_initial - ), - max=( - wait_max.total_seconds() - if isinstance(wait_max, dt.timedelta) - else wait_max - ), - exp_base=wait_exp_base, - jitter=( - wait_jitter.total_seconds() - if isinstance(wait_jitter, dt.timedelta) - else wait_jitter - ), - ), "stop": _make_stop( attempts=attempts, timeout=( @@ -448,6 +477,10 @@ def from_params( _t_a_retrying=_LAZY_NO_ASYNC_RETRY, ) + inst._t_kw["wait"] = inst._jittered_backoff_for_rcs + + return inst + def with_name( self, name: str, args: tuple[object, ...], kw: dict[str, object] ) -> _RetryContextIterator: @@ -459,7 +492,7 @@ def with_name( def __iter__(self) -> Iterator[Attempt]: if not CONFIG.is_active: for r in _t.Retrying(reraise=True, stop=_STOP_NO_RETRY): - yield Attempt(r) + yield Attempt(r, None) return @@ -469,7 +502,7 @@ def __iter__(self) -> Iterator[Attempt]: ), **self._t_kw, ): - yield Attempt(r) + yield Attempt(r, self._backoff_for_attempt_number) def __aiter__(self) -> AsyncIterator[Attempt]: if CONFIG.is_active: @@ -486,7 +519,31 @@ def __aiter__(self) -> AsyncIterator[Attempt]: return self async def __anext__(self) -> Attempt: - return Attempt(await self._t_a_retrying.__anext__()) + return Attempt( + await self._t_a_retrying.__anext__(), + self._backoff_for_attempt_number, + ) + + def _backoff_for_attempt_number(self, num: int) -> float: + """ + Compute a jitter-less lower bound for backoff number *num*. + + *num* is 1-based. + """ + return min( + self._wait_max, + self._wait_initial * (self._wait_exp_base ** (num - 1)), + ) + + def _jittered_backoff_for_rcs(self, rcs: _t.RetryCallState) -> float: + """ + Compute the backoff for *rcs*. + """ + return min( + self._wait_max, + self._backoff_for_attempt_number(rcs.attempt_number) + + self._random.uniform(0, self._wait_jitter), + ) def _make_before_sleep( diff --git a/tests/test_async.py b/tests/test_async.py index 20de2be..46ddcf5 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -199,18 +199,13 @@ async def test_next_wait(): """ The next_wait property is updated. """ - i = 0 - - async for attempt in stamina.retry_context(on=ValueError, wait_max=0.001): + async for attempt in stamina.retry_context(on=ValueError, wait_max=0.0001): with attempt: - if i == 0: - assert 0.0 == attempt.next_wait + assert pytest.approx(0.0001) == attempt.next_wait - i += 1 + if attempt.num == 1: raise ValueError - assert pytest.approx(0.001) == attempt.next_wait - async def test_retry_blocks_can_be_disabled(): """ diff --git a/tests/test_sync.py b/tests/test_sync.py index d63348c..b819a7a 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -4,6 +4,8 @@ import datetime as dt +from types import SimpleNamespace + import pytest import tenacity @@ -171,17 +173,30 @@ def test_next_wait(): """ The next_wait property is updated. """ - i = 0 - for attempt in stamina.retry_context(on=ValueError, wait_max=0.001): + for attempt in stamina.retry_context(on=ValueError, wait_max=0.0001): with attempt: - if i == 0: - assert 0.0 == attempt.next_wait + assert pytest.approx(0.0001) == attempt.next_wait - i += 1 + if attempt.num == 1: raise ValueError - assert pytest.approx(0.001) == attempt.next_wait + +def test_backoff_computation_clamps(): + """ + The backoff returned by _RetryContextIterator._backoff_for_attempt_number + and _RetryContextIterator._jittered_backoff_for_rcs never exceeds wait_max. + """ + rci = stamina.retry_context(on=ValueError, wait_max=0.42) + + for i in range(1, 10): + backoff = rci._backoff_for_attempt_number(i) + assert backoff <= 0.42 + + jittered = rci._jittered_backoff_for_rcs( + SimpleNamespace(attempt_number=i) + ) + assert jittered <= 0.42 class TestMakeStop: