Skip to content

Commit

Permalink
Auto merge of rust-lang#74024 - Folyd:master, r=m-ou-se
Browse files Browse the repository at this point in the history
Improve slice.binary_search_by()'s best-case performance to O(1)

This PR aimed to improve the [slice.binary_search_by()](https://doc.rust-lang.org/std/primitive.slice.html#method.binary_search_by)'s best-case performance to O(1).

# Noticed

I don't know why the docs of `binary_search_by` said `"If there are multiple matches, then any one of the matches could be returned."`, but the implementation isn't the same thing. Actually, it returns the **last one** if multiple matches found.

Then we got two options:

## If returns the last one is the correct or desired result

Then I can rectify the docs and revert my changes.

## If the docs are correct or desired result

Then my changes can be merged after fully reviewed.

However, if my PR gets merged, another issue raised: this could be a **breaking change** since if multiple matches found, the returning order no longer the last one instead of it could be any one.

For example:
```rust
let mut s = vec![0, 1, 1, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55];
let num = 1;
let idx = s.binary_search(&num);
s.insert(idx, 2);

// Old implementations
assert_eq!(s, [0, 1, 1, 1, 1, 2, 2, 3, 5, 8, 13, 21, 34, 42, 55]);

// New implementations
assert_eq!(s, [0, 1, 1, 1, 2, 1, 2, 3, 5, 8, 13, 21, 34, 42, 55]);
```

# Benchmarking

**Old implementations**
```sh
$ ./x.py bench --stage 1 library/libcore
test slice::binary_search_l1           ... bench:          59 ns/iter (+/- 4)
test slice::binary_search_l1_with_dups ... bench:          59 ns/iter (+/- 3)
test slice::binary_search_l2           ... bench:          76 ns/iter (+/- 5)
test slice::binary_search_l2_with_dups ... bench:          77 ns/iter (+/- 17)
test slice::binary_search_l3           ... bench:         183 ns/iter (+/- 23)
test slice::binary_search_l3_with_dups ... bench:         185 ns/iter (+/- 19)
```

**New implementations (1)**

Implemented by this PR.
```rust
if cmp == Equal {
    return Ok(mid);
} else if cmp == Less {
    base = mid
}
```
```sh
$ ./x.py bench --stage 1 library/libcore
test slice::binary_search_l1           ... bench:          58 ns/iter (+/- 2)
test slice::binary_search_l1_with_dups ... bench:          37 ns/iter (+/- 4)
test slice::binary_search_l2           ... bench:          76 ns/iter (+/- 3)
test slice::binary_search_l2_with_dups ... bench:          57 ns/iter (+/- 6)
test slice::binary_search_l3           ... bench:         200 ns/iter (+/- 30)
test slice::binary_search_l3_with_dups ... bench:         157 ns/iter (+/- 6)

$ ./x.py bench --stage 1 library/libcore
test slice::binary_search_l1           ... bench:          59 ns/iter (+/- 8)
test slice::binary_search_l1_with_dups ... bench:          37 ns/iter (+/- 2)
test slice::binary_search_l2           ... bench:          77 ns/iter (+/- 2)
test slice::binary_search_l2_with_dups ... bench:          57 ns/iter (+/- 2)
test slice::binary_search_l3           ... bench:         198 ns/iter (+/- 21)
test slice::binary_search_l3_with_dups ... bench:         158 ns/iter (+/- 11)

```

**New implementations (2)**

Suggested by `@nbdd0121` in [comment](rust-lang#74024 (comment)).
```rust
base = if cmp == Greater { base } else { mid };
if cmp == Equal { break }
```

```sh
$ ./x.py bench --stage 1 library/libcore
test slice::binary_search_l1           ... bench:          59 ns/iter (+/- 7)
test slice::binary_search_l1_with_dups ... bench:          37 ns/iter (+/- 5)
test slice::binary_search_l2           ... bench:          75 ns/iter (+/- 3)
test slice::binary_search_l2_with_dups ... bench:          56 ns/iter (+/- 3)
test slice::binary_search_l3           ... bench:         195 ns/iter (+/- 15)
test slice::binary_search_l3_with_dups ... bench:         151 ns/iter (+/- 7)

$ ./x.py bench --stage 1 library/libcore
test slice::binary_search_l1           ... bench:          57 ns/iter (+/- 2)
test slice::binary_search_l1_with_dups ... bench:          38 ns/iter (+/- 2)
test slice::binary_search_l2           ... bench:          77 ns/iter (+/- 11)
test slice::binary_search_l2_with_dups ... bench:          57 ns/iter (+/- 4)
test slice::binary_search_l3           ... bench:         194 ns/iter (+/- 15)
test slice::binary_search_l3_with_dups ... bench:         151 ns/iter (+/- 18)

```

I run some benchmarking testings against on two implementations. The new implementation has a lot of improvement in duplicates cases, while in `binary_search_l3` case, it's a little bit slower than the old one.
  • Loading branch information
bors committed Mar 5, 2021
2 parents 8fd946c + 3eb5bee commit caca212
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 29 deletions.
44 changes: 38 additions & 6 deletions library/core/benches/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@ enum Cache {
L3,
}

impl Cache {
fn size(&self) -> usize {
match self {
Cache::L1 => 1000, // 8kb
Cache::L2 => 10_000, // 80kb
Cache::L3 => 1_000_000, // 8Mb
}
}
}

fn binary_search<F>(b: &mut Bencher, cache: Cache, mapper: F)
where
F: Fn(usize) -> usize,
{
let size = match cache {
Cache::L1 => 1000, // 8kb
Cache::L2 => 10_000, // 80kb
Cache::L3 => 1_000_000, // 8Mb
};
let size = cache.size();
let v = (0..size).map(&mapper).collect::<Vec<_>>();
let mut r = 0usize;
b.iter(move || {
Expand All @@ -24,7 +30,18 @@ where
// Lookup the whole range to get 50% hits and 50% misses.
let i = mapper(r % size);
black_box(v.binary_search(&i).is_ok());
})
});
}

fn binary_search_worst_case(b: &mut Bencher, cache: Cache) {
let size = cache.size();

let mut v = vec![0; size];
let i = 1;
v[size - 1] = i;
b.iter(move || {
black_box(v.binary_search(&i).is_ok());
});
}

#[bench]
Expand Down Expand Up @@ -57,6 +74,21 @@ fn binary_search_l3_with_dups(b: &mut Bencher) {
binary_search(b, Cache::L3, |i| i / 16 * 16);
}

#[bench]
fn binary_search_l1_worst_case(b: &mut Bencher) {
binary_search_worst_case(b, Cache::L1);
}

#[bench]
fn binary_search_l2_worst_case(b: &mut Bencher) {
binary_search_worst_case(b, Cache::L2);
}

#[bench]
fn binary_search_l3_worst_case(b: &mut Bencher) {
binary_search_worst_case(b, Cache::L3);
}

macro_rules! rotate {
($fn:ident, $n:expr, $mapper:expr) => {
#[bench]
Expand Down
44 changes: 25 additions & 19 deletions library/core/src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#![stable(feature = "rust1", since = "1.0.0")]

use crate::cmp::Ordering::{self, Equal, Greater, Less};
use crate::cmp::Ordering::{self, Greater, Less};
use crate::marker::Copy;
use crate::mem;
use crate::num::NonZeroUsize;
Expand Down Expand Up @@ -2185,25 +2185,31 @@ impl<T> [T] {
where
F: FnMut(&'a T) -> Ordering,
{
let s = self;
let mut size = s.len();
if size == 0 {
return Err(0);
}
let mut base = 0usize;
while size > 1 {
let half = size / 2;
let mid = base + half;
// SAFETY: the call is made safe by the following inconstants:
// - `mid >= 0`: by definition
// - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
let cmp = f(unsafe { s.get_unchecked(mid) });
base = if cmp == Greater { base } else { mid };
size -= half;
let mut size = self.len();
let mut left = 0;
let mut right = size;
while left < right {
let mid = left + size / 2;

// SAFETY: the call is made safe by the following invariants:
// - `mid >= 0`
// - `mid < size`: `mid` is limited by `[left; right)` bound.
let cmp = f(unsafe { self.get_unchecked(mid) });

// The reason why we use if/else control flow rather than match
// is because match reorders comparison operations, which is perf sensitive.
// This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
if cmp == Less {
left = mid + 1;
} else if cmp == Greater {
right = mid;
} else {
return Ok(mid);
}

size = right - left;
}
// SAFETY: base is always in [0, size) because base <= mid.
let cmp = f(unsafe { s.get_unchecked(base) });
if cmp == Equal { Ok(base) } else { Err(base + (cmp == Less) as usize) }
Err(left)
}

/// Binary searches this sorted slice with a key extraction function.
Expand Down
21 changes: 17 additions & 4 deletions library/core/tests/slice.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use core::cell::Cell;
use core::cmp::Ordering;
use core::result::Result::{Err, Ok};

#[test]
Expand Down Expand Up @@ -64,6 +65,17 @@ fn test_binary_search() {
assert_eq!(b.binary_search(&6), Err(4));
assert_eq!(b.binary_search(&7), Ok(4));
assert_eq!(b.binary_search(&8), Err(5));

let b = [(); usize::MAX];
assert_eq!(b.binary_search(&()), Ok(usize::MAX / 2));
}

#[test]
fn test_binary_search_by_overflow() {
let b = [(); usize::MAX];
assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(usize::MAX / 2));
assert_eq!(b.binary_search_by(|_| Ordering::Greater), Err(0));
assert_eq!(b.binary_search_by(|_| Ordering::Less), Err(usize::MAX));
}

#[test]
Expand All @@ -73,13 +85,13 @@ fn test_binary_search_implementation_details() {
let b = [1, 1, 2, 2, 3, 3, 3];
assert_eq!(b.binary_search(&1), Ok(1));
assert_eq!(b.binary_search(&2), Ok(3));
assert_eq!(b.binary_search(&3), Ok(6));
assert_eq!(b.binary_search(&3), Ok(5));
let b = [1, 1, 1, 1, 1, 3, 3, 3, 3];
assert_eq!(b.binary_search(&1), Ok(4));
assert_eq!(b.binary_search(&3), Ok(8));
assert_eq!(b.binary_search(&3), Ok(7));
let b = [1, 1, 1, 1, 3, 3, 3, 3, 3];
assert_eq!(b.binary_search(&1), Ok(3));
assert_eq!(b.binary_search(&3), Ok(8));
assert_eq!(b.binary_search(&1), Ok(2));
assert_eq!(b.binary_search(&3), Ok(4));
}

#[test]
Expand Down Expand Up @@ -1982,6 +1994,7 @@ fn test_copy_within_panics_dest_too_long() {
// The length is only 13, so a slice of length 4 starting at index 10 is out of bounds.
bytes.copy_within(0..4, 10);
}

#[test]
#[should_panic(expected = "slice index starts at 2 but ends at 1")]
fn test_copy_within_panics_src_inverted() {
Expand Down

0 comments on commit caca212

Please sign in to comment.