diff --git a/src/iterator.rs b/src/iterator.rs index 01cf6f9..e55e0ca 100644 --- a/src/iterator.rs +++ b/src/iterator.rs @@ -97,14 +97,11 @@ mod test { // Check that each entry doesn't exist // Check that every number is "hit" (as they'd have to be) for a perfect bijection // Check that the number is within range - let mut set = HashSet::new(); + let mut set = HashSet::with_capacity(length.get() as usize); for elem in it { - let set_result = set.get(&elem); - // Make sure there are no duplicates - assert!(set_result.is_none()); - set.insert(elem); + assert!(set.insert(elem)); } // Need to dereference the types into regular integers let mut result: Vec = set.into_iter().collect(); diff --git a/src/kensler.rs b/src/kensler.rs index 58e6517..490fcba 100644 --- a/src/kensler.rs +++ b/src/kensler.rs @@ -7,7 +7,7 @@ use crate::error::{PermutationError, PermutationResult}; #[cfg(feature = "use-rand")] use rand::prelude::*; -use std::num::NonZeroU32; +use std::num::{NonZeroU32, Wrapping}; /// The `HashedPermutation` struct stores the initial `seed` and `length` of the permutation /// vector. In other words, if you want to shuffle the numbers from `0..n`, then `length = n`. @@ -56,17 +56,12 @@ impl HashedPermutation { max_shuffle: self.length.get(), }); } - let mut i = input; + let mut i = Wrapping(input); let n = self.length.get(); - let seed = self.seed; - let mut w = n - 1; - w |= w >> 1; - w |= w >> 2; - w |= w >> 4; - w |= w >> 8; - w |= w >> 16; - - while i >= n { + let seed = Wrapping(self.seed); + let w = Wrapping(n.checked_next_power_of_two().map_or(u32::MAX, |x| x - 1)); + + while i.0 >= n { i ^= seed; i *= 0xe170893d; i ^= seed >> 16; @@ -75,7 +70,7 @@ impl HashedPermutation { i *= 0x0929eb3f; i ^= seed >> 23; i ^= (i & w) >> 1; - i *= 1 | seed >> 27; + i *= Wrapping(1) | seed >> 27; i *= 0x6935fa69; i ^= (i & w) >> 11; i *= 0x74dcb303; @@ -86,7 +81,7 @@ impl HashedPermutation { i &= w; i ^= i >> 5; } - Ok((i + seed) % n) + Ok((i + seed).0 % n) } } @@ -143,19 +138,15 @@ mod test { // Check that each entry doesn't exist // Check that every number is "hit" (as they'd have to be) for a perfect bijection // Check that the number is within range - let mut map = HashMap::new(); + let mut map = HashMap::with_capacity(length.get() as usize); for i in 0..perm.length.get() { let res = perm.shuffle(i); let res = res.unwrap(); - let map_result = map.get(&res); - assert!(map_result.is_none()); - map.insert(res, i); + assert!(map.insert(res, i).is_none()); } - // Need to dereference the types into regular integers - let mut keys_vec: Vec = map.keys().into_iter().map(|k| *k).collect(); + let (mut keys_vec, mut vals_vec): (Vec, Vec) = map.iter().unzip(); keys_vec.sort(); - let mut vals_vec: Vec = map.values().into_iter().map(|v| *v).collect(); vals_vec.sort(); let ground_truth: Vec = (0..length.get()).collect(); assert_eq!(ground_truth, keys_vec);