Skip to content

Commit

Permalink
Rollup merge of rust-lang#65226 - ssomers:master, r=bluss
Browse files Browse the repository at this point in the history
BTreeSet symmetric_difference & union optimized

No scalability changes, but:
- Grew the cmp_opt function (shared by symmetric_difference & union) into a MergeIter, with less memory overhead than the pairs of Peekable iterators now, speeding up ~20% on my machine (not so clear on Travis though, I actually switched it off there because it wasn't consistent about identical code). Mainly meant to improve readability by sharing code, though it does end up using more lines of code. Extending and reusing the MergeIter in btree_map might be better, but I'm not sure that's possible or desirable. This MergeIter probably pretends to be more generic than it is, yet doesn't declare to be an iterator because there's no need to, it's only there to help construct genuine iterators SymmetricDifference & Union.
- Compact the code of rust-lang#64820 by moving if/else into match guards.

r? @bluss
  • Loading branch information
Centril authored Oct 19, 2019
2 parents 7e4ff91 + 5697432 commit a6b5c80
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 121 deletions.
239 changes: 119 additions & 120 deletions src/liballoc/collections/btree/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// to TreeMap

use core::borrow::Borrow;
use core::cmp::Ordering::{self, Less, Greater, Equal};
use core::cmp::Ordering::{Less, Greater, Equal};
use core::cmp::{max, min};
use core::fmt::{self, Debug};
use core::iter::{Peekable, FromIterator, FusedIterator};
Expand Down Expand Up @@ -109,6 +109,77 @@ pub struct Range<'a, T: 'a> {
iter: btree_map::Range<'a, T, ()>,
}

/// Core of SymmetricDifference and Union.
/// More efficient than btree.map.MergeIter,
/// and crucially for SymmetricDifference, nexts() reports on both sides.
#[derive(Clone)]
struct MergeIterInner<I>
where I: Iterator,
I::Item: Copy,
{
a: I,
b: I,
peeked: Option<MergeIterPeeked<I>>,
}

#[derive(Copy, Clone, Debug)]
enum MergeIterPeeked<I: Iterator> {
A(I::Item),
B(I::Item),
}

impl<I> MergeIterInner<I>
where I: ExactSizeIterator + FusedIterator,
I::Item: Copy + Ord,
{
fn new(a: I, b: I) -> Self {
MergeIterInner { a, b, peeked: None }
}

fn nexts(&mut self) -> (Option<I::Item>, Option<I::Item>) {
let mut a_next = match self.peeked {
Some(MergeIterPeeked::A(next)) => Some(next),
_ => self.a.next(),
};
let mut b_next = match self.peeked {
Some(MergeIterPeeked::B(next)) => Some(next),
_ => self.b.next(),
};
let ord = match (a_next, b_next) {
(None, None) => Equal,
(_, None) => Less,
(None, _) => Greater,
(Some(a1), Some(b1)) => a1.cmp(&b1),
};
self.peeked = match ord {
Less => b_next.take().map(MergeIterPeeked::B),
Equal => None,
Greater => a_next.take().map(MergeIterPeeked::A),
};
(a_next, b_next)
}

fn lens(&self) -> (usize, usize) {
match self.peeked {
Some(MergeIterPeeked::A(_)) => (1 + self.a.len(), self.b.len()),
Some(MergeIterPeeked::B(_)) => (self.a.len(), 1 + self.b.len()),
_ => (self.a.len(), self.b.len()),
}
}
}

impl<I> Debug for MergeIterInner<I>
where I: Iterator + Debug,
I::Item: Copy + Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MergeIterInner")
.field(&self.a)
.field(&self.b)
.finish()
}
}

/// A lazy iterator producing elements in the difference of `BTreeSet`s.
///
/// This `struct` is created by the [`difference`] method on [`BTreeSet`].
Expand All @@ -120,6 +191,7 @@ pub struct Range<'a, T: 'a> {
pub struct Difference<'a, T: 'a> {
inner: DifferenceInner<'a, T>,
}
#[derive(Debug)]
enum DifferenceInner<'a, T: 'a> {
Stitch {
// iterate all of self and some of other, spotting matches along the way
Expand All @@ -137,21 +209,7 @@ enum DifferenceInner<'a, T: 'a> {
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
DifferenceInner::Stitch {
self_iter,
other_iter,
} => f
.debug_tuple("Difference")
.field(&self_iter)
.field(&other_iter)
.finish(),
DifferenceInner::Search {
self_iter,
other_set: _,
} => f.debug_tuple("Difference").field(&self_iter).finish(),
DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(),
}
f.debug_tuple("Difference").field(&self.inner).finish()
}
}

Expand All @@ -163,18 +221,12 @@ impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
/// [`BTreeSet`]: struct.BTreeSet.html
/// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference
#[stable(feature = "rust1", since = "1.0.0")]
pub struct SymmetricDifference<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
}
pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);

#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("SymmetricDifference")
.field(&self.a)
.field(&self.b)
.finish()
f.debug_tuple("SymmetricDifference").field(&self.0).finish()
}
}

Expand All @@ -189,6 +241,7 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
pub struct Intersection<'a, T: 'a> {
inner: IntersectionInner<'a, T>,
}
#[derive(Debug)]
enum IntersectionInner<'a, T: 'a> {
Stitch {
// iterate similarly sized sets jointly, spotting matches along the way
Expand All @@ -206,23 +259,7 @@ enum IntersectionInner<'a, T: 'a> {
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
IntersectionInner::Stitch {
a,
b,
} => f
.debug_tuple("Intersection")
.field(&a)
.field(&b)
.finish(),
IntersectionInner::Search {
small_iter,
large_set: _,
} => f.debug_tuple("Intersection").field(&small_iter).finish(),
IntersectionInner::Answer(answer) => {
f.debug_tuple("Intersection").field(&answer).finish()
}
}
f.debug_tuple("Intersection").field(&self.inner).finish()
}
}

Expand All @@ -234,18 +271,12 @@ impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
/// [`BTreeSet`]: struct.BTreeSet.html
/// [`union`]: struct.BTreeSet.html#method.union
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Union<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
}
pub struct Union<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);

#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Union")
.field(&self.a)
.field(&self.b)
.finish()
f.debug_tuple("Union").field(&self.0).finish()
}
}

Expand Down Expand Up @@ -355,19 +386,16 @@ impl<T: Ord> BTreeSet<T> {
self_iter.next_back();
DifferenceInner::Iterate(self_iter)
}
_ => {
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
DifferenceInner::Search {
self_iter: self.iter(),
other_set: other,
}
} else {
DifferenceInner::Stitch {
self_iter: self.iter(),
other_iter: other.iter().peekable(),
}
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
DifferenceInner::Search {
self_iter: self.iter(),
other_set: other,
}
}
_ => DifferenceInner::Stitch {
self_iter: self.iter(),
other_iter: other.iter().peekable(),
},
},
}
}
Expand Down Expand Up @@ -396,10 +424,7 @@ impl<T: Ord> BTreeSet<T> {
pub fn symmetric_difference<'a>(&'a self,
other: &'a BTreeSet<T>)
-> SymmetricDifference<'a, T> {
SymmetricDifference {
a: self.iter().peekable(),
b: other.iter().peekable(),
}
SymmetricDifference(MergeIterInner::new(self.iter(), other.iter()))
}

/// Visits the values representing the intersection,
Expand Down Expand Up @@ -447,24 +472,22 @@ impl<T: Ord> BTreeSet<T> {
(Greater, _) | (_, Less) => IntersectionInner::Answer(None),
(Equal, _) => IntersectionInner::Answer(Some(self_min)),
(_, Equal) => IntersectionInner::Answer(Some(self_max)),
_ => {
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
IntersectionInner::Search {
small_iter: self.iter(),
large_set: other,
}
} else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
IntersectionInner::Search {
small_iter: other.iter(),
large_set: self,
}
} else {
IntersectionInner::Stitch {
a: self.iter(),
b: other.iter(),
}
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
IntersectionInner::Search {
small_iter: self.iter(),
large_set: other,
}
}
_ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
IntersectionInner::Search {
small_iter: other.iter(),
large_set: self,
}
}
_ => IntersectionInner::Stitch {
a: self.iter(),
b: other.iter(),
},
},
}
}
Expand All @@ -489,10 +512,7 @@ impl<T: Ord> BTreeSet<T> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn union<'a>(&'a self, other: &'a BTreeSet<T>) -> Union<'a, T> {
Union {
a: self.iter().peekable(),
b: other.iter().peekable(),
}
Union(MergeIterInner::new(self.iter(), other.iter()))
}

/// Clears the set, removing all values.
Expand Down Expand Up @@ -1166,15 +1186,6 @@ impl<'a, T> DoubleEndedIterator for Range<'a, T> {
#[stable(feature = "fused", since = "1.26.0")]
impl<T> FusedIterator for Range<'_, T> {}

/// Compares `x` and `y`, but return `short` if x is None and `long` if y is None
fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering {
match (x, y) {
(None, _) => short,
(_, None) => long,
(Some(x1), Some(y1)) => x1.cmp(y1),
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for Difference<'_, T> {
fn clone(&self) -> Self {
Expand Down Expand Up @@ -1261,10 +1272,7 @@ impl<T: Ord> FusedIterator for Difference<'_, T> {}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for SymmetricDifference<'_, T> {
fn clone(&self) -> Self {
SymmetricDifference {
a: self.a.clone(),
b: self.b.clone(),
}
SymmetricDifference(self.0.clone())
}
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -1273,19 +1281,19 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {

fn next(&mut self) -> Option<&'a T> {
loop {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => return self.a.next(),
Equal => {
self.a.next();
self.b.next();
}
Greater => return self.b.next(),
let (a_next, b_next) = self.0.nexts();
if a_next.and(b_next).is_none() {
return a_next.or(b_next);
}
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.a.len() + self.b.len()))
let (a_len, b_len) = self.0.lens();
// No checked_add, because even if a and b refer to the same set,
// and T is an empty type, the storage overhead of sets limits
// the number of elements to less than half the range of usize.
(0, Some(a_len + b_len))
}
}

Expand All @@ -1311,7 +1319,7 @@ impl<T> Clone for Intersection<'_, T> {
small_iter: small_iter.clone(),
large_set,
},
IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()),
IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer),
},
}
}
Expand Down Expand Up @@ -1365,30 +1373,21 @@ impl<T: Ord> FusedIterator for Intersection<'_, T> {}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for Union<'_, T> {
fn clone(&self) -> Self {
Union {
a: self.a.clone(),
b: self.b.clone(),
}
Union(self.0.clone())
}
}
#[stable(feature = "rust1", since = "1.0.0")]
impl<'a, T: Ord> Iterator for Union<'a, T> {
type Item = &'a T;

fn next(&mut self) -> Option<&'a T> {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => self.a.next(),
Equal => {
self.b.next();
self.a.next()
}
Greater => self.b.next(),
}
let (a_next, b_next) = self.0.nexts();
a_next.or(b_next)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let a_len = self.a.len();
let b_len = self.b.len();
let (a_len, b_len) = self.0.lens();
// No checked_add - see SymmetricDifference::size_hint.
(max(a_len, b_len), Some(a_len + b_len))
}
}
Expand Down
Loading

0 comments on commit a6b5c80

Please sign in to comment.