diff --git a/src/exponential.rs b/src/exponential.rs index d9d25fd..8e36ab9 100644 --- a/src/exponential.rs +++ b/src/exponential.rs @@ -139,36 +139,40 @@ impl Iterator for ExponentialBackoff { } self.attempts += 1; - match self.current_delay { + let mut tmp_cur = match self.current_delay { None => { // If current_delay is None, it's must be the first time to retry. - let mut cur = self.min_delay; - self.current_delay = Some(cur); - - // If jitter is enabled, add random jitter based on min delay. - if self.jitter { - cur += self.min_delay.mul_f32(fastrand::f32()); - } - - Some(cur) + self.current_delay = Some(self.min_delay); + self.min_delay } Some(mut cur) => { // If current delay larger than max delay, we should stop increment anymore. if let Some(max_delay) = self.max_delay { if cur < max_delay { - cur = cur.mul_f32(self.factor); + cur = saturating_mul(cur, self.factor); } + if cur > max_delay { + cur = max_delay; + } + } else { + cur = saturating_mul(cur, self.factor); } self.current_delay = Some(cur); - - // If jitter is enabled, add random jitter based on min delay. - if self.jitter { - cur += self.min_delay.mul_f32(fastrand::f32()); - } - - Some(cur) + cur } + }; + // If jitter is enabled, add random jitter based on min delay. + if self.jitter { + tmp_cur = tmp_cur.saturating_add(self.min_delay.mul_f32(fastrand::f32())); } + Some(tmp_cur) + } +} + +pub(crate) fn saturating_mul(d: Duration, rhs: f32) -> Duration { + match Duration::try_from_secs_f32(rhs * d.as_secs_f32()) { + Ok(v) => v, + Err(_) => Duration::MAX, } } @@ -231,7 +235,7 @@ mod tests { } #[test] - fn test_exponential_max_delay() { + fn test_exponential_max_delay_with_default() { let mut exp = ExponentialBuilder::default() .with_max_delay(Duration::from_secs(2)) .build(); @@ -242,6 +246,56 @@ mod tests { assert_eq!(None, exp.next()); } + #[test] + fn test_exponential_max_delay_without_default_1() { + let mut exp = ExponentialBuilder { + jitter: false, + factor: 10_000_000_000_f32, + min_delay: Duration::from_secs(1), + max_delay: None, + max_times: None, + } + .build(); + + assert_eq!(Some(Duration::from_secs(1)), exp.next()); + assert_eq!(Some(Duration::from_secs(10_000_000_000)), exp.next()); + assert_eq!(Some(Duration::MAX), exp.next()); + assert_eq!(Some(Duration::MAX), exp.next()); + } + + #[test] + fn test_exponential_max_delay_without_default_2() { + let mut exp = ExponentialBuilder { + jitter: true, + factor: 10_000_000_000_f32, + min_delay: Duration::from_secs(10_000_000_000), + max_delay: None, + max_times: Some(2), + } + .build(); + let v = exp.next().expect("value must valid"); + assert!(v >= Duration::from_secs(10_000_000_000), "current: {v:?}"); + assert!(v < Duration::from_secs(20_000_000_000), "current: {v:?}"); + assert_eq!(Some(Duration::MAX), exp.next()); + assert_eq!(None, exp.next()); + } + + #[test] + fn test_exponential_max_delay_without_default_3() { + let mut exp = ExponentialBuilder { + jitter: false, + factor: 10_000_000_000_f32, + min_delay: Duration::from_secs(10_000_000_000), + max_delay: Some(Duration::from_secs(60_000_000_000)), + max_times: Some(3), + } + .build(); + assert_eq!(Some(Duration::from_secs(10_000_000_000)), exp.next()); + assert_eq!(Some(Duration::from_secs(60_000_000_000)), exp.next()); + assert_eq!(Some(Duration::from_secs(60_000_000_000)), exp.next()); + assert_eq!(None, exp.next()); + } + #[test] fn test_exponential_max_times() { let mut exp = ExponentialBuilder::default().with_max_times(1).build();