From 56974329d1eab8dd990d53b6cd6978bbf6a615b7 Mon Sep 17 00:00:00 2001 From: Stein Somers Date: Wed, 9 Oct 2019 01:07:57 +0200 Subject: [PATCH] BTreeSet symmetric_difference & union optimized, cleaned --- src/liballoc/collections/btree/set.rs | 239 +++++++++++++------------- src/liballoc/tests/btree/set.rs | 26 ++- 2 files changed, 144 insertions(+), 121 deletions(-) diff --git a/src/liballoc/collections/btree/set.rs b/src/liballoc/collections/btree/set.rs index 8250fc38ccd1c..f0796354e00c3 100644 --- a/src/liballoc/collections/btree/set.rs +++ b/src/liballoc/collections/btree/set.rs @@ -2,7 +2,7 @@ // to TreeMap use core::borrow::Borrow; -use core::cmp::Ordering::{self, Less, Greater, Equal}; +use core::cmp::Ordering::{Less, Greater, Equal}; use core::cmp::{max, min}; use core::fmt::{self, Debug}; use core::iter::{Peekable, FromIterator, FusedIterator}; @@ -109,6 +109,77 @@ pub struct Range<'a, T: 'a> { iter: btree_map::Range<'a, T, ()>, } +/// Core of SymmetricDifference and Union. +/// More efficient than btree.map.MergeIter, +/// and crucially for SymmetricDifference, nexts() reports on both sides. +#[derive(Clone)] +struct MergeIterInner + where I: Iterator, + I::Item: Copy, +{ + a: I, + b: I, + peeked: Option>, +} + +#[derive(Copy, Clone, Debug)] +enum MergeIterPeeked { + A(I::Item), + B(I::Item), +} + +impl MergeIterInner + where I: ExactSizeIterator + FusedIterator, + I::Item: Copy + Ord, +{ + fn new(a: I, b: I) -> Self { + MergeIterInner { a, b, peeked: None } + } + + fn nexts(&mut self) -> (Option, Option) { + let mut a_next = match self.peeked { + Some(MergeIterPeeked::A(next)) => Some(next), + _ => self.a.next(), + }; + let mut b_next = match self.peeked { + Some(MergeIterPeeked::B(next)) => Some(next), + _ => self.b.next(), + }; + let ord = match (a_next, b_next) { + (None, None) => Equal, + (_, None) => Less, + (None, _) => Greater, + (Some(a1), Some(b1)) => a1.cmp(&b1), + }; + self.peeked = match ord { + Less => b_next.take().map(MergeIterPeeked::B), + Equal => None, + Greater => a_next.take().map(MergeIterPeeked::A), + }; + (a_next, b_next) + } + + fn lens(&self) -> (usize, usize) { + match self.peeked { + Some(MergeIterPeeked::A(_)) => (1 + self.a.len(), self.b.len()), + Some(MergeIterPeeked::B(_)) => (self.a.len(), 1 + self.b.len()), + _ => (self.a.len(), self.b.len()), + } + } +} + +impl Debug for MergeIterInner + where I: Iterator + Debug, + I::Item: Copy + Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("MergeIterInner") + .field(&self.a) + .field(&self.b) + .finish() + } +} + /// A lazy iterator producing elements in the difference of `BTreeSet`s. /// /// This `struct` is created by the [`difference`] method on [`BTreeSet`]. @@ -120,6 +191,7 @@ pub struct Range<'a, T: 'a> { pub struct Difference<'a, T: 'a> { inner: DifferenceInner<'a, T>, } +#[derive(Debug)] enum DifferenceInner<'a, T: 'a> { Stitch { // iterate all of self and some of other, spotting matches along the way @@ -137,21 +209,7 @@ enum DifferenceInner<'a, T: 'a> { #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for Difference<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.inner { - DifferenceInner::Stitch { - self_iter, - other_iter, - } => f - .debug_tuple("Difference") - .field(&self_iter) - .field(&other_iter) - .finish(), - DifferenceInner::Search { - self_iter, - other_set: _, - } => f.debug_tuple("Difference").field(&self_iter).finish(), - DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(), - } + f.debug_tuple("Difference").field(&self.inner).finish() } } @@ -163,18 +221,12 @@ impl fmt::Debug for Difference<'_, T> { /// [`BTreeSet`]: struct.BTreeSet.html /// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference #[stable(feature = "rust1", since = "1.0.0")] -pub struct SymmetricDifference<'a, T: 'a> { - a: Peekable>, - b: Peekable>, -} +pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner>); #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for SymmetricDifference<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("SymmetricDifference") - .field(&self.a) - .field(&self.b) - .finish() + f.debug_tuple("SymmetricDifference").field(&self.0).finish() } } @@ -189,6 +241,7 @@ impl fmt::Debug for SymmetricDifference<'_, T> { pub struct Intersection<'a, T: 'a> { inner: IntersectionInner<'a, T>, } +#[derive(Debug)] enum IntersectionInner<'a, T: 'a> { Stitch { // iterate similarly sized sets jointly, spotting matches along the way @@ -206,23 +259,7 @@ enum IntersectionInner<'a, T: 'a> { #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for Intersection<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.inner { - IntersectionInner::Stitch { - a, - b, - } => f - .debug_tuple("Intersection") - .field(&a) - .field(&b) - .finish(), - IntersectionInner::Search { - small_iter, - large_set: _, - } => f.debug_tuple("Intersection").field(&small_iter).finish(), - IntersectionInner::Answer(answer) => { - f.debug_tuple("Intersection").field(&answer).finish() - } - } + f.debug_tuple("Intersection").field(&self.inner).finish() } } @@ -234,18 +271,12 @@ impl fmt::Debug for Intersection<'_, T> { /// [`BTreeSet`]: struct.BTreeSet.html /// [`union`]: struct.BTreeSet.html#method.union #[stable(feature = "rust1", since = "1.0.0")] -pub struct Union<'a, T: 'a> { - a: Peekable>, - b: Peekable>, -} +pub struct Union<'a, T: 'a>(MergeIterInner>); #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for Union<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Union") - .field(&self.a) - .field(&self.b) - .finish() + f.debug_tuple("Union").field(&self.0).finish() } } @@ -355,19 +386,16 @@ impl BTreeSet { self_iter.next_back(); DifferenceInner::Iterate(self_iter) } - _ => { - if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF { - DifferenceInner::Search { - self_iter: self.iter(), - other_set: other, - } - } else { - DifferenceInner::Stitch { - self_iter: self.iter(), - other_iter: other.iter().peekable(), - } + _ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { + DifferenceInner::Search { + self_iter: self.iter(), + other_set: other, } } + _ => DifferenceInner::Stitch { + self_iter: self.iter(), + other_iter: other.iter().peekable(), + }, }, } } @@ -396,10 +424,7 @@ impl BTreeSet { pub fn symmetric_difference<'a>(&'a self, other: &'a BTreeSet) -> SymmetricDifference<'a, T> { - SymmetricDifference { - a: self.iter().peekable(), - b: other.iter().peekable(), - } + SymmetricDifference(MergeIterInner::new(self.iter(), other.iter())) } /// Visits the values representing the intersection, @@ -447,24 +472,22 @@ impl BTreeSet { (Greater, _) | (_, Less) => IntersectionInner::Answer(None), (Equal, _) => IntersectionInner::Answer(Some(self_min)), (_, Equal) => IntersectionInner::Answer(Some(self_max)), - _ => { - if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF { - IntersectionInner::Search { - small_iter: self.iter(), - large_set: other, - } - } else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF { - IntersectionInner::Search { - small_iter: other.iter(), - large_set: self, - } - } else { - IntersectionInner::Stitch { - a: self.iter(), - b: other.iter(), - } + _ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { + IntersectionInner::Search { + small_iter: self.iter(), + large_set: other, + } + } + _ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => { + IntersectionInner::Search { + small_iter: other.iter(), + large_set: self, } } + _ => IntersectionInner::Stitch { + a: self.iter(), + b: other.iter(), + }, }, } } @@ -489,10 +512,7 @@ impl BTreeSet { /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn union<'a>(&'a self, other: &'a BTreeSet) -> Union<'a, T> { - Union { - a: self.iter().peekable(), - b: other.iter().peekable(), - } + Union(MergeIterInner::new(self.iter(), other.iter())) } /// Clears the set, removing all values. @@ -1166,15 +1186,6 @@ impl<'a, T> DoubleEndedIterator for Range<'a, T> { #[stable(feature = "fused", since = "1.26.0")] impl FusedIterator for Range<'_, T> {} -/// Compares `x` and `y`, but return `short` if x is None and `long` if y is None -fn cmp_opt(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering { - match (x, y) { - (None, _) => short, - (_, None) => long, - (Some(x1), Some(y1)) => x1.cmp(y1), - } -} - #[stable(feature = "rust1", since = "1.0.0")] impl Clone for Difference<'_, T> { fn clone(&self) -> Self { @@ -1261,10 +1272,7 @@ impl FusedIterator for Difference<'_, T> {} #[stable(feature = "rust1", since = "1.0.0")] impl Clone for SymmetricDifference<'_, T> { fn clone(&self) -> Self { - SymmetricDifference { - a: self.a.clone(), - b: self.b.clone(), - } + SymmetricDifference(self.0.clone()) } } #[stable(feature = "rust1", since = "1.0.0")] @@ -1273,19 +1281,19 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> { fn next(&mut self) -> Option<&'a T> { loop { - match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) { - Less => return self.a.next(), - Equal => { - self.a.next(); - self.b.next(); - } - Greater => return self.b.next(), + let (a_next, b_next) = self.0.nexts(); + if a_next.and(b_next).is_none() { + return a_next.or(b_next); } } } fn size_hint(&self) -> (usize, Option) { - (0, Some(self.a.len() + self.b.len())) + let (a_len, b_len) = self.0.lens(); + // No checked_add, because even if a and b refer to the same set, + // and T is an empty type, the storage overhead of sets limits + // the number of elements to less than half the range of usize. + (0, Some(a_len + b_len)) } } @@ -1311,7 +1319,7 @@ impl Clone for Intersection<'_, T> { small_iter: small_iter.clone(), large_set, }, - IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()), + IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer), }, } } @@ -1365,10 +1373,7 @@ impl FusedIterator for Intersection<'_, T> {} #[stable(feature = "rust1", since = "1.0.0")] impl Clone for Union<'_, T> { fn clone(&self) -> Self { - Union { - a: self.a.clone(), - b: self.b.clone(), - } + Union(self.0.clone()) } } #[stable(feature = "rust1", since = "1.0.0")] @@ -1376,19 +1381,13 @@ impl<'a, T: Ord> Iterator for Union<'a, T> { type Item = &'a T; fn next(&mut self) -> Option<&'a T> { - match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) { - Less => self.a.next(), - Equal => { - self.b.next(); - self.a.next() - } - Greater => self.b.next(), - } + let (a_next, b_next) = self.0.nexts(); + a_next.or(b_next) } fn size_hint(&self) -> (usize, Option) { - let a_len = self.a.len(); - let b_len = self.b.len(); + let (a_len, b_len) = self.0.lens(); + // No checked_add - see SymmetricDifference::size_hint. (max(a_len, b_len), Some(a_len + b_len)) } } diff --git a/src/liballoc/tests/btree/set.rs b/src/liballoc/tests/btree/set.rs index 5c611fd21d21b..e4883abc8b56c 100644 --- a/src/liballoc/tests/btree/set.rs +++ b/src/liballoc/tests/btree/set.rs @@ -221,6 +221,18 @@ fn test_symmetric_difference() { &[-2, 1, 5, 11, 14, 22]); } +#[test] +fn test_symmetric_difference_size_hint() { + let x: BTreeSet = [2, 4].iter().copied().collect(); + let y: BTreeSet = [1, 2, 3].iter().copied().collect(); + let mut iter = x.symmetric_difference(&y); + assert_eq!(iter.size_hint(), (0, Some(5))); + assert_eq!(iter.next(), Some(&1)); + assert_eq!(iter.size_hint(), (0, Some(4))); + assert_eq!(iter.next(), Some(&3)); + assert_eq!(iter.size_hint(), (0, Some(1))); +} + #[test] fn test_union() { fn check_union(a: &[i32], b: &[i32], expected: &[i32]) { @@ -235,6 +247,18 @@ fn test_union() { &[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]); } +#[test] +fn test_union_size_hint() { + let x: BTreeSet = [2, 4].iter().copied().collect(); + let y: BTreeSet = [1, 2, 3].iter().copied().collect(); + let mut iter = x.union(&y); + assert_eq!(iter.size_hint(), (3, Some(5))); + assert_eq!(iter.next(), Some(&1)); + assert_eq!(iter.size_hint(), (2, Some(4))); + assert_eq!(iter.next(), Some(&2)); + assert_eq!(iter.size_hint(), (1, Some(2))); +} + #[test] // Only tests the simple function definition with respect to intersection fn test_is_disjoint() { @@ -244,7 +268,7 @@ fn test_is_disjoint() { } #[test] -// Also tests the trivial function definition of is_superset +// Also implicitly tests the trivial function definition of is_superset fn test_is_subset() { fn is_subset(a: &[i32], b: &[i32]) -> bool { let set_a = a.iter().collect::>();