diff --git a/src/combinations.rs b/src/combinations.rs index 68a59c5e4..57c440c5c 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -1,7 +1,8 @@ use std::fmt; -use std::iter::FusedIterator; +use std::iter::{Fuse, FusedIterator}; use super::lazy_buffer::LazyBuffer; +use super::size_hint::{self, SizeHint}; use alloc::vec::Vec; /// An iterator to iterate through all the `k`-length combinations in an iterator. @@ -52,9 +53,15 @@ impl Combinations { #[inline] pub fn n(&self) -> usize { self.pool.len() } + /// Fill the pool to get its length. + pub(crate) fn real_n(&mut self) -> usize { + while self.pool.get_next() {} + self.pool.len() + } + /// Returns a reference to the source iterator. #[inline] - pub(crate) fn src(&self) -> &I { &self.pool.it } + pub(crate) fn src(&self) -> &Fuse { &self.pool.it } /// Resets this `Combinations` back to an initial state for combinations of length /// `k` over the same pool data source. If `k` is larger than the current length @@ -77,6 +84,20 @@ impl Combinations { self.pool.prefill(k); } } + + fn remaining_for(&self, n: usize) -> Option { + let k = self.k(); + if self.first { + binomial(n, k) + } else { + self.indices + .iter() + .enumerate() + .fold(Some(0), |sum, (k0, n0)| { + sum.and_then(|s| s.checked_add(binomial(n - 1 - *n0, k - k0)?)) + }) + } + } } impl Iterator for Combinations @@ -120,9 +141,32 @@ impl Iterator for Combinations // Create result vector based on the indices Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect()) } + + fn size_hint(&self) -> SizeHint { + size_hint::try_map(self.pool.size_hint(), |n| self.remaining_for(n)) + } + + fn count(mut self) -> usize { + let n = self.real_n(); + self.remaining_for(n).expect("Iterator count greater than usize::MAX") + } } impl FusedIterator for Combinations where I: Iterator, I::Item: Clone {} + +pub(crate) fn binomial(mut n: usize, mut k: usize) -> Option { + if n < k { + return Some(0); + } + // n! / (n - k)! / k! but trying to avoid it overflows: + k = (n - k).min(k); + let mut c = 1; + for i in 1..=k { + c = (c / i).checked_mul(n)? + c % i * n / i; + n -= 1; + } + Some(c) +} diff --git a/src/combinations_with_replacement.rs b/src/combinations_with_replacement.rs index 0fec9671a..1e5ffb573 100644 --- a/src/combinations_with_replacement.rs +++ b/src/combinations_with_replacement.rs @@ -2,7 +2,9 @@ use alloc::vec::Vec; use std::fmt; use std::iter::FusedIterator; +use super::combinations::binomial; use super::lazy_buffer::LazyBuffer; +use super::size_hint::{self, SizeHint}; /// An iterator to iterate through all the `n`-length combinations in an iterator, with replacement. /// @@ -36,6 +38,21 @@ where fn current(&self) -> Vec { self.indices.iter().map(|i| self.pool[*i].clone()).collect() } + + fn remaining_for(&self, n: usize) -> Option { + let k_perms = |n: usize, k: usize| binomial((n + k).saturating_sub(1), k); + let k = self.indices.len(); + if self.first { + k_perms(n, k) + } else { + self.indices + .iter() + .enumerate() + .fold(Some(0), |sum, (k0, n0)| { + sum.and_then(|s| s.checked_add(k_perms(n - 1 - *n0, k - k0)?)) + }) + } + } } /// Create a new `CombinationsWithReplacement` from a clonable iterator. @@ -100,6 +117,16 @@ where None => None, } } + + fn size_hint(&self) -> SizeHint { + size_hint::try_map(self.pool.size_hint(), |n| self.remaining_for(n)) + } + + fn count(mut self) -> usize { + while self.pool.get_next() {} + let n = self.pool.len(); + self.remaining_for(n).expect("Iterator count greater than usize::MAX") + } } impl FusedIterator for CombinationsWithReplacement diff --git a/src/lazy_buffer.rs b/src/lazy_buffer.rs index ca24062aa..88ee06c7c 100644 --- a/src/lazy_buffer.rs +++ b/src/lazy_buffer.rs @@ -1,10 +1,12 @@ +use std::iter::Fuse; use std::ops::Index; use alloc::vec::Vec; +use crate::size_hint::{self, SizeHint}; + #[derive(Debug, Clone)] pub struct LazyBuffer { - pub it: I, - done: bool, + pub it: Fuse, buffer: Vec, } @@ -14,8 +16,7 @@ where { pub fn new(it: I) -> LazyBuffer { LazyBuffer { - it, - done: false, + it: it.fuse(), buffer: Vec::new(), } } @@ -24,27 +25,24 @@ where self.buffer.len() } + pub fn size_hint(&self) -> SizeHint { + size_hint::add_scalar(self.it.size_hint(), self.len()) + } + pub fn get_next(&mut self) -> bool { - if self.done { - return false; - } if let Some(x) = self.it.next() { self.buffer.push(x); true } else { - self.done = true; false } } pub fn prefill(&mut self, len: usize) { let buffer_len = self.buffer.len(); - - if !self.done && len > buffer_len { + if len > buffer_len { let delta = len - buffer_len; - self.buffer.extend(self.it.by_ref().take(delta)); - self.done = self.buffer.len() < len; } } } diff --git a/src/permutations.rs b/src/permutations.rs index d03b85262..9c29248bf 100644 --- a/src/permutations.rs +++ b/src/permutations.rs @@ -3,6 +3,7 @@ use std::fmt; use std::iter::once; use super::lazy_buffer::LazyBuffer; +use super::size_hint::{self, SizeHint}; /// An iterator adaptor that iterates through all the `k`-permutations of the /// elements from an iterator. @@ -47,11 +48,6 @@ enum CompleteState { } } -enum CompleteStateRemaining { - Known(usize), - Overflow, -} - impl fmt::Debug for Permutations where I: Iterator + fmt::Debug, I::Item: fmt::Debug, @@ -72,14 +68,8 @@ pub fn permutations(iter: I, k: usize) -> Permutations { }; } - let mut enough_vals = true; - - while vals.len() < k { - if !vals.get_next() { - enough_vals = false; - break; - } - } + vals.prefill(k); + let enough_vals = vals.len() == k; let state = if enough_vals { PermutationState::StartUnknownLen { k } @@ -122,42 +112,42 @@ where } fn count(self) -> usize { - fn from_complete(complete_state: CompleteState) -> usize { - match complete_state.remaining() { - CompleteStateRemaining::Known(count) => count, - CompleteStateRemaining::Overflow => { - panic!("Iterator count greater than usize::MAX"); - } - } - } - let Permutations { vals, state } = self; match state { PermutationState::StartUnknownLen { k } => { let n = vals.len() + vals.it.count(); - let complete_state = CompleteState::Start { n, k }; - - from_complete(complete_state) + CompleteState::Start { n, k }.count() } PermutationState::OngoingUnknownLen { k, min_n } => { let prev_iteration_count = min_n - k + 1; let n = vals.len() + vals.it.count(); - let complete_state = CompleteState::Start { n, k }; - - from_complete(complete_state) - prev_iteration_count + CompleteState::Start { n, k }.count() - prev_iteration_count }, - PermutationState::Complete(state) => from_complete(state), + PermutationState::Complete(state) => state.count(), PermutationState::Empty => 0 } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> SizeHint { match self.state { - PermutationState::StartUnknownLen { .. } | - PermutationState::OngoingUnknownLen { .. } => (0, None), // TODO can we improve this lower bound? + // Note: the product for `CompleteState::Start` in `remaining` increases with `n`. + PermutationState::StartUnknownLen { k } => { + size_hint::try_map( + self.vals.size_hint(), + |n| CompleteState::Start { n, k }.remaining(), + ) + } + PermutationState::OngoingUnknownLen { k, min_n } => { + let prev_iteration_count = min_n - k + 1; + size_hint::try_map(self.vals.size_hint(), |n| { + CompleteState::Start { n, k } + .remaining() + .and_then(|count| count.checked_sub(prev_iteration_count)) + }) + } PermutationState::Complete(ref state) => match state.remaining() { - CompleteStateRemaining::Known(count) => (count, Some(count)), - CompleteStateRemaining::Overflow => (::std::usize::MAX, None) + Some(count) => (count, Some(count)), + None => (::std::usize::MAX, None) } PermutationState::Empty => (0, Some(0)) } @@ -185,7 +175,7 @@ where let mut complete_state = CompleteState::Start { n, k }; // Advance the complete-state iterator to the correct point - for _ in 0..(prev_iteration_count + 1) { + for _ in 0..=prev_iteration_count { complete_state.advance(); } @@ -238,40 +228,30 @@ impl CompleteState { } } - fn remaining(&self) -> CompleteStateRemaining { - use self::CompleteStateRemaining::{Known, Overflow}; - + /// The remaining count of elements, if it does not overflow. + fn remaining(&self) -> Option { match *self { CompleteState::Start { n, k } => { if n < k { - return Known(0); + return Some(0); } - - let count: Option = (n - k + 1..n + 1).fold(Some(1), |acc, i| { + (n - k + 1..n + 1).fold(Some(1), |acc, i| { acc.and_then(|acc| acc.checked_mul(i)) - }); - - match count { - Some(count) => Known(count), - None => Overflow - } + }) } CompleteState::Ongoing { ref indices, ref cycles } => { - let mut count: usize = 0; - - for (i, &c) in cycles.iter().enumerate() { - let radix = indices.len() - i; - let next_count = count.checked_mul(radix) - .and_then(|count| count.checked_add(c)); - - count = match next_count { - Some(count) => count, - None => { return Overflow; } - }; - } - - Known(count) + cycles.iter().enumerate().fold(Some(0), |acc, (i, c)| { + acc.and_then(|count| { + let radix = indices.len() - i; + count.checked_mul(radix)?.checked_add(*c) + }) + }) } } } + + /// The remaining count of elements, panics if it overflows. + fn count(&self) -> usize { + self.remaining().expect("Iterator count greater than usize::MAX") + } } diff --git a/src/powerset.rs b/src/powerset.rs index 4d7685b12..b43223462 100644 --- a/src/powerset.rs +++ b/src/powerset.rs @@ -3,7 +3,7 @@ use std::iter::FusedIterator; use std::usize; use alloc::vec::Vec; -use super::combinations::{Combinations, combinations}; +use super::combinations::{Combinations, binomial, combinations}; use super::size_hint; /// An iterator to iterate through the powerset of the elements from an iterator. @@ -81,6 +81,13 @@ impl Iterator for Powerset (0, self_total.1) } } + + fn count(mut self) -> usize { + let k = self.combs.k(); + let n = self.combs.real_n(); + // It could be `(1 << n) - self.pos` but `1 << n` might overflow. + self.combs.count() + (k + 1..=n).map(|k| binomial(n, k).unwrap()).sum::() + } } impl FusedIterator for Powerset diff --git a/src/size_hint.rs b/src/size_hint.rs index 71ea1412b..920964cac 100644 --- a/src/size_hint.rs +++ b/src/size_hint.rs @@ -117,3 +117,17 @@ pub fn min(a: SizeHint, b: SizeHint) -> SizeHint { }; (lower, upper) } + +/// Try to apply a function `f` on both bounds of a `SizeHint`, failure means overflow. +/// +/// For the resulting size hint to be correct, `f` must be increasing. +#[inline] +pub fn try_map(sh: SizeHint, mut f: F) -> SizeHint +where + F: FnMut(usize) -> Option, +{ + let (mut low, mut hi) = sh; + low = f(low).unwrap_or(usize::MAX); + hi = hi.and_then(f); + (low, hi) +} diff --git a/tests/test_std.rs b/tests/test_std.rs index 77207d87e..49b137c24 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -909,12 +909,42 @@ fn combinations_zero() { it::assert_equal((0..0).combinations(0), vec![vec![]]); } +#[test] +fn combinations_range_count() { + for n in 0..6 { + for k in 0..=n { + let len = (n - k + 1..=n).product::() / (1..=k).product::(); + let mut it = (0..n).combinations(k); + for count in (0..=len).rev() { + assert_eq!(it.size_hint(), (count, Some(count))); + assert_eq!(it.clone().count(), count); + assert_eq!(it.next().is_none(), count == 0); + } + } + } +} + #[test] fn permutations_zero() { it::assert_equal((1..3).permutations(0), vec![vec![]]); it::assert_equal((0..0).permutations(0), vec![vec![]]); } +#[test] +fn permutations_range_count() { + for n in 0..6 { + for k in 0..=n { + let len: usize = (n - k + 1..=n).product(); + let mut it = (0..n).permutations(k); + for count in (0..=len).rev() { + assert_eq!(it.size_hint(), (count, Some(count))); + assert_eq!(it.clone().count(), count); + assert_eq!(it.next().is_none(), count == 0); + } + } + } +} + #[test] fn combinations_with_replacement() { // Pool smaller than n @@ -948,6 +978,21 @@ fn combinations_with_replacement() { ); } +#[test] +fn combinations_with_replacement_range_count() { + for n in 0..6 { + for k in 0..=n { + let len = (n..n + k).product::() / (1..=k).product::(); + let mut it = (0..n).combinations_with_replacement(k); + for count in (0..=len).rev() { + assert_eq!(it.size_hint(), (count, Some(count))); + assert_eq!(it.clone().count(), count); + assert_eq!(it.next().is_none(), count == 0); + } + } + } +} + #[test] fn powerset() { it::assert_equal((0..0).powerset(), vec![vec![]]); @@ -963,6 +1008,11 @@ fn powerset() { assert_eq!((0..4).powerset().count(), 1 << 4); assert_eq!((0..8).powerset().count(), 1 << 8); assert_eq!((0..16).powerset().count(), 1 << 16); + let mut it = (0..8).powerset(); + for count in (0..=(1 << 8)).rev() { + assert_eq!(it.clone().count(), count); + assert_eq!(it.next().is_none(), count == 0); + } } #[test]