diff --git a/src/blocking_retry.rs b/src/blocking_retry.rs index 9f5038f..6141889 100644 --- a/src/blocking_retry.rs +++ b/src/blocking_retry.rs @@ -58,13 +58,19 @@ where } /// Retry struct generated by [`Retryable`]. -pub struct BlockingRetry Result> { +pub struct BlockingRetry< + B: Backoff, + T, + E, + F: FnMut() -> Result, + RF = fn(&E) -> bool, + NF = fn(&E, Duration), +> { backoff: B, - retryable: fn(&E) -> bool, - notify: fn(&E, Duration), + retryable: RF, + notify: NF, f: F, } - impl BlockingRetry where B: Backoff, @@ -79,7 +85,15 @@ where f, } } +} +impl BlockingRetry +where + B: Backoff, + F: FnMut() -> Result, + RF: FnMut(&E) -> bool, + NF: FnMut(&E, Duration), +{ /// Set the conditions for retrying. /// /// If not specified, we treat all errors as retryable. @@ -105,9 +119,13 @@ where /// Ok(()) /// } /// ``` - pub fn when(mut self, retryable: fn(&E) -> bool) -> Self { - self.retryable = retryable; - self + pub fn when bool>(self, retryable: RN) -> BlockingRetry { + BlockingRetry { + backoff: self.backoff, + retryable, + notify: self.notify, + f: self.f, + } } /// Set to notify for everything retrying. @@ -140,9 +158,13 @@ where /// Ok(()) /// } /// ``` - pub fn notify(mut self, notify: fn(&E, Duration)) -> Self { - self.notify = notify; - self + pub fn notify(self, notify: NN) -> BlockingRetry { + BlockingRetry { + backoff: self.backoff, + retryable: self.retryable, + notify, + f: self.f, + } } /// Call the retried function. @@ -245,4 +267,32 @@ mod tests { assert_eq!(*error_times.lock().unwrap(), 4); Ok(()) } + + #[test] + fn test_fn_mut_when_and_notify() -> anyhow::Result<()> { + let mut calls_retryable: Vec<()> = vec![]; + let mut calls_notify: Vec<()> = vec![]; + + let f = || Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")); + + let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1)); + let result = f + .retry(&backoff) + .when(|_| { + calls_retryable.push(()); + true + }) + .notify(|_, _| { + calls_notify.push(()); + }) + .call(); + + assert!(result.is_err()); + assert_eq!("retryable", result.unwrap_err().to_string()); + // `f` always returns error "retryable", so it should be executed + // 4 times (retry 3 times). + assert_eq!(calls_retryable.len(), 4); + assert_eq!(calls_notify.len(), 3); + Ok(()) + } } diff --git a/src/retry.rs b/src/retry.rs index c7e8734..232c7ba 100644 --- a/src/retry.rs +++ b/src/retry.rs @@ -85,10 +85,18 @@ where /// Retry struct generated by [`Retryable`]. #[pin_project] -pub struct Retry>, FutureFn: FnMut() -> Fut> { +pub struct Retry< + B: Backoff, + T, + E, + Fut: Future>, + FutureFn: FnMut() -> Fut, + RF = fn(&E) -> bool, + NF = fn(&E, Duration), +> { backoff: B, - retryable: fn(&E) -> bool, - notify: fn(&E, Duration), + retryable: RF, + notify: NF, future_fn: FutureFn, #[pin] @@ -111,7 +119,16 @@ where state: State::Idle, } } +} +impl Retry +where + B: Backoff, + Fut: Future>, + FutureFn: FnMut() -> Fut, + RF: FnMut(&E) -> bool, + NF: FnMut(&E, Duration), +{ /// Set the conditions for retrying. /// /// If not specified, we treat all errors as retryable. @@ -141,9 +158,17 @@ where /// Ok(()) /// } /// ``` - pub fn when(mut self, retryable: fn(&E) -> bool) -> Self { - self.retryable = retryable; - self + pub fn when bool>( + self, + retryable: RN, + ) -> Retry { + Retry { + backoff: self.backoff, + retryable, + notify: self.notify, + future_fn: self.future_fn, + state: self.state, + } } /// Set to notify for everything retrying. @@ -179,9 +204,17 @@ where /// Ok(()) /// } /// ``` - pub fn notify(mut self, notify: fn(&E, Duration)) -> Self { - self.notify = notify; - self + pub fn notify( + self, + notify: NN, + ) -> Retry { + Retry { + backoff: self.backoff, + retryable: self.retryable, + notify, + future_fn: self.future_fn, + state: self.state, + } } } @@ -201,11 +234,13 @@ enum State>> { Sleeping(#[pin] Pin>), } -impl Future for Retry +impl Future for Retry where B: Backoff, Fut: Future>, FutureFn: FnMut() -> Fut, + RF: FnMut(&E) -> bool, + NF: FnMut(&E, Duration), { type Output = Result; @@ -320,4 +355,32 @@ mod tests { assert_eq!(*error_times.lock().await, 4); Ok(()) } + + #[tokio::test] + async fn test_fn_mut_when_and_notify() -> anyhow::Result<()> { + let mut calls_retryable: Vec<()> = vec![]; + let mut calls_notify: Vec<()> = vec![]; + + let f = || async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) }; + + let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1)); + let result = f + .retry(&backoff) + .when(|_| { + calls_retryable.push(()); + true + }) + .notify(|_, _| { + calls_notify.push(()); + }) + .await; + + assert!(result.is_err()); + assert_eq!("retryable", result.unwrap_err().to_string()); + // `f` always returns error "retryable", so it should be executed + // 4 times (retry 3 times). + assert_eq!(calls_retryable.len(), 4); + assert_eq!(calls_notify.len(), 3); + Ok(()) + } }