diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 11e2faa..dc8dfbc 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -15,8 +15,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import dataclasses import functools import sys import threading @@ -97,6 +96,29 @@ WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any]) +dataclass_kwargs = {} +if sys.version_info >= (3, 10): + dataclass_kwargs.update({"slots": True}) + + +@dataclasses.dataclass(**dataclass_kwargs) +class IterState: + actions: t.List[t.Callable[["RetryCallState"], t.Any]] = dataclasses.field( + default_factory=list + ) + retry_run_result: bool = False + delay_since_first_attempt: int = 0 + stop_run_result: bool = False + is_explicit_retry: bool = False + + def reset(self) -> None: + self.actions = [] + self.retry_run_result = False + self.delay_since_first_attempt = 0 + self.stop_run_result = False + self.is_explicit_retry = False + + class TryAgain(Exception): """Always retry the executed function when raised.""" @@ -287,6 +309,14 @@ def statistics(self) -> t.Dict[str, t.Any]: self._local.statistics = t.cast(t.Dict[str, t.Any], {}) return self._local.statistics + @property + def iter_state(self) -> IterState: + try: + return self._local.iter_state # type: ignore[no-any-return] + except AttributeError: + self._local.iter_state = IterState() + return self._local.iter_state + def wraps(self, f: WrappedFn) -> WrappedFn: """Wrap a function for retrying. @@ -313,20 +343,13 @@ def begin(self) -> None: self.statistics["attempt_number"] = 1 self.statistics["idle_for"] = 0 - def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa - fut = retry_state.outcome - if fut is None: - if self.before is not None: - self.before(retry_state) - return DoAttempt() - - is_explicit_retry = fut.failed and isinstance(fut.exception(), TryAgain) - if not (is_explicit_retry or self.retry(retry_state)): - return fut.result() + def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: + self.iter_state.actions.append(fn) - if self.after is not None: - self.after(retry_state) + def _run_retry(self, retry_state: "RetryCallState") -> None: + self.iter_state.retry_run_result = self.retry(retry_state) + def _run_wait(self, retry_state: "RetryCallState") -> None: if self.wait: sleep = self.wait(retry_state) else: @@ -334,24 +357,75 @@ def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.A retry_state.upcoming_sleep = sleep + def _run_stop(self, retry_state: "RetryCallState") -> None: self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start - if self.stop(retry_state): + self.iter_state.stop_run_result = self.stop(retry_state) + + def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa + self._begin_iter(retry_state) + result = None + for action in self.iter_state.actions: + result = action(retry_state) + return result + + def _begin_iter(self, retry_state: "RetryCallState") -> None: # noqa + self.iter_state.reset() + + fut = retry_state.outcome + if fut is None: + if self.before is not None: + self._add_action_func(self.before) + self._add_action_func(lambda rs: DoAttempt()) + return + + self.iter_state.is_explicit_retry = fut.failed and isinstance( + fut.exception(), TryAgain + ) + if not self.iter_state.is_explicit_retry: + self._add_action_func(self._run_retry) + self._add_action_func(self._post_retry_check_actions) + + def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None: + if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result): + self._add_action_func(lambda rs: rs.outcome.result()) + return + + if self.after is not None: + self._add_action_func(self.after) + + self._add_action_func(self._run_wait) + self._add_action_func(self._run_stop) + self._add_action_func(self._post_stop_check_actions) + + def _post_stop_check_actions(self, retry_state: "RetryCallState") -> None: + if self.iter_state.stop_run_result: if self.retry_error_callback: - return self.retry_error_callback(retry_state) - retry_exc = self.retry_error_cls(fut) - if self.reraise: - raise retry_exc.reraise() - raise retry_exc from fut.exception() + self._add_action_func(self.retry_error_callback) + return + + def exc_check(rs: "RetryCallState") -> None: + fut = t.cast(Future, rs.outcome) + retry_exc = self.retry_error_cls(fut) + if self.reraise: + raise retry_exc.reraise() + raise retry_exc from fut.exception() + + self._add_action_func(exc_check) + return + + def next_action(rs: "RetryCallState") -> None: + sleep = rs.upcoming_sleep + rs.next_action = RetryAction(sleep) + rs.idle_for += sleep + self.statistics["idle_for"] += sleep + self.statistics["attempt_number"] += 1 - retry_state.next_action = RetryAction(sleep) - retry_state.idle_for += sleep - self.statistics["idle_for"] += sleep - self.statistics["attempt_number"] += 1 + self._add_action_func(next_action) if self.before_sleep is not None: - self.before_sleep(retry_state) + self._add_action_func(self.before_sleep) - return DoSleep(sleep) + self._add_action_func(lambda rs: DoSleep(rs.upcoming_sleep)) def __iter__(self) -> t.Generator[AttemptManager, None, None]: self.begin()