Skip to content

Commit 1b8cac8

Browse files
authored
Rollup merge of rust-lang#81706 - SkiFire13:document-binaryheap-unsafe, r=Mark-Simulacrum
Document BinaryHeap unsafe functions `BinaryHeap` contains some private safe functions but that are actually unsafe to call. This PR marks them `unsafe` and documents all the `unsafe` function calls inside them. While doing this I might also have found a bug: some "SAFETY" comments in `sift_down_range` and `sift_down_to_bottom` are valid only if you assume that `child` doesn't overflow. However it may overflow if `end > isize::MAX` which can be true for ZSTs (but I think only for them). I guess the easiest fix would be to skip any sifting if `mem::size_of::<T> == 0`. Probably conflicts with rust-lang#81127 but solving the eventual merge conflict should be pretty easy.
2 parents c853094 + 3ec1a28 commit 1b8cac8

File tree

1 file changed

+117
-49
lines changed

1 file changed

+117
-49
lines changed

library/alloc/src/collections/binary_heap.rs

+117-49
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
275275
impl<T: Ord> Drop for PeekMut<'_, T> {
276276
fn drop(&mut self) {
277277
if self.sift {
278-
self.heap.sift_down(0);
278+
// SAFETY: PeekMut is only instantiated for non-empty heaps.
279+
unsafe { self.heap.sift_down(0) };
279280
}
280281
}
281282
}
@@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
431432
self.data.pop().map(|mut item| {
432433
if !self.is_empty() {
433434
swap(&mut item, &mut self.data[0]);
434-
self.sift_down_to_bottom(0);
435+
// SAFETY: !self.is_empty() means that self.len() > 0
436+
unsafe { self.sift_down_to_bottom(0) };
435437
}
436438
item
437439
})
@@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
473475
pub fn push(&mut self, item: T) {
474476
let old_len = self.len();
475477
self.data.push(item);
476-
self.sift_up(0, old_len);
478+
// SAFETY: Since we pushed a new item it means that
479+
// old_len = self.len() - 1 < self.len()
480+
unsafe { self.sift_up(0, old_len) };
477481
}
478482

479483
/// Consumes the `BinaryHeap` and returns a vector in sorted
@@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
506510
let ptr = self.data.as_mut_ptr();
507511
ptr::swap(ptr, ptr.add(end));
508512
}
509-
self.sift_down_range(0, end);
513+
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
514+
// 0 < 1 <= end <= self.len() - 1 < self.len()
515+
// Which means 0 < end and end < self.len().
516+
unsafe { self.sift_down_range(0, end) };
510517
}
511518
self.into_vec()
512519
}
@@ -519,78 +526,139 @@ impl<T: Ord> BinaryHeap<T> {
519526
// the hole is filled back at the end of its scope, even on panic.
520527
// Using a hole reduces the constant factor compared to using swaps,
521528
// which involves twice as many moves.
522-
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
523-
unsafe {
524-
// Take out the value at `pos` and create a hole.
525-
let mut hole = Hole::new(&mut self.data, pos);
526-
527-
while hole.pos() > start {
528-
let parent = (hole.pos() - 1) / 2;
529-
if hole.element() <= hole.get(parent) {
530-
break;
531-
}
532-
hole.move_to(parent);
529+
530+
/// # Safety
531+
///
532+
/// The caller must guarantee that `pos < self.len()`.
533+
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
534+
// Take out the value at `pos` and create a hole.
535+
// SAFETY: The caller guarantees that pos < self.len()
536+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
537+
538+
while hole.pos() > start {
539+
let parent = (hole.pos() - 1) / 2;
540+
541+
// SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
542+
// and so hole.pos() - 1 can't underflow.
543+
// This guarantees that parent < hole.pos() so
544+
// it's a valid index and also != hole.pos().
545+
if hole.element() <= unsafe { hole.get(parent) } {
546+
break;
533547
}
534-
hole.pos()
548+
549+
// SAFETY: Same as above
550+
unsafe { hole.move_to(parent) };
535551
}
552+
553+
hole.pos()
536554
}
537555

538556
/// Take an element at `pos` and move it down the heap,
539557
/// while its children are larger.
540-
fn sift_down_range(&mut self, pos: usize, end: usize) {
541-
unsafe {
542-
let mut hole = Hole::new(&mut self.data, pos);
543-
let mut child = 2 * pos + 1;
544-
while child < end - 1 {
545-
// compare with the greater of the two children
546-
child += (hole.get(child) <= hole.get(child + 1)) as usize;
547-
// if we are already in order, stop.
548-
if hole.element() >= hole.get(child) {
549-
return;
550-
}
551-
hole.move_to(child);
552-
child = 2 * hole.pos() + 1;
553-
}
554-
if child == end - 1 && hole.element() < hole.get(child) {
555-
hole.move_to(child);
558+
///
559+
/// # Safety
560+
///
561+
/// The caller must guarantee that `pos < end <= self.len()`.
562+
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
563+
// SAFETY: The caller guarantees that pos < end <= self.len().
564+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
565+
let mut child = 2 * hole.pos() + 1;
566+
567+
// Loop invariant: child == 2 * hole.pos() + 1.
568+
while child < end - 1 {
569+
// compare with the greater of the two children
570+
// SAFETY: child < end - 1 < self.len() and
571+
// child + 1 < end <= self.len(), so they're valid indexes.
572+
// child == 2 * hole.pos() + 1 != hole.pos() and
573+
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
574+
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
575+
// if T is a ZST
576+
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
577+
578+
// if we are already in order, stop.
579+
// SAFETY: child is now either the old child or the old child+1
580+
// We already proven that both are < self.len() and != hole.pos()
581+
if hole.element() >= unsafe { hole.get(child) } {
582+
return;
556583
}
584+
585+
// SAFETY: same as above.
586+
unsafe { hole.move_to(child) };
587+
child = 2 * hole.pos() + 1;
588+
}
589+
590+
// SAFETY: && short circuit, which means that in the
591+
// second condition it's already true that child == end - 1 < self.len().
592+
if child == end - 1 && hole.element() < unsafe { hole.get(child) } {
593+
// SAFETY: child is already proven to be a valid index and
594+
// child == 2 * hole.pos() + 1 != hole.pos().
595+
unsafe { hole.move_to(child) };
557596
}
558597
}
559598

560-
fn sift_down(&mut self, pos: usize) {
599+
/// # Safety
600+
///
601+
/// The caller must guarantee that `pos < self.len()`.
602+
unsafe fn sift_down(&mut self, pos: usize) {
561603
let len = self.len();
562-
self.sift_down_range(pos, len);
604+
// SAFETY: pos < len is guaranteed by the caller and
605+
// obviously len = self.len() <= self.len().
606+
unsafe { self.sift_down_range(pos, len) };
563607
}
564608

565609
/// Take an element at `pos` and move it all the way down the heap,
566610
/// then sift it up to its position.
567611
///
568612
/// Note: This is faster when the element is known to be large / should
569613
/// be closer to the bottom.
570-
fn sift_down_to_bottom(&mut self, mut pos: usize) {
614+
///
615+
/// # Safety
616+
///
617+
/// The caller must guarantee that `pos < self.len()`.
618+
unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
571619
let end = self.len();
572620
let start = pos;
573-
unsafe {
574-
let mut hole = Hole::new(&mut self.data, pos);
575-
let mut child = 2 * pos + 1;
576-
while child < end - 1 {
577-
child += (hole.get(child) <= hole.get(child + 1)) as usize;
578-
hole.move_to(child);
579-
child = 2 * hole.pos() + 1;
580-
}
581-
if child == end - 1 {
582-
hole.move_to(child);
583-
}
584-
pos = hole.pos;
621+
622+
// SAFETY: The caller guarantees that pos < self.len().
623+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
624+
let mut child = 2 * hole.pos() + 1;
625+
626+
// Loop invariant: child == 2 * hole.pos() + 1.
627+
while child < end - 1 {
628+
// SAFETY: child < end - 1 < self.len() and
629+
// child + 1 < end <= self.len(), so they're valid indexes.
630+
// child == 2 * hole.pos() + 1 != hole.pos() and
631+
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
632+
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
633+
// if T is a ZST
634+
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
635+
636+
// SAFETY: Same as above
637+
unsafe { hole.move_to(child) };
638+
child = 2 * hole.pos() + 1;
585639
}
586-
self.sift_up(start, pos);
640+
641+
if child == end - 1 {
642+
// SAFETY: child == end - 1 < self.len(), so it's a valid index
643+
// and child == 2 * hole.pos() + 1 != hole.pos().
644+
unsafe { hole.move_to(child) };
645+
}
646+
pos = hole.pos();
647+
drop(hole);
648+
649+
// SAFETY: pos is the position in the hole and was already proven
650+
// to be a valid index.
651+
unsafe { self.sift_up(start, pos) };
587652
}
588653

589654
fn rebuild(&mut self) {
590655
let mut n = self.len() / 2;
591656
while n > 0 {
592657
n -= 1;
593-
self.sift_down(n);
658+
// SAFETY: n starts from self.len() / 2 and goes down to 0.
659+
// The only case when !(n < self.len()) is if
660+
// self.len() == 0, but it's ruled out by the loop condition.
661+
unsafe { self.sift_down(n) };
594662
}
595663
}
596664

0 commit comments

Comments
 (0)