diff --git a/library/alloc/src/collections/binary_heap.rs b/library/alloc/src/collections/binary_heap.rs index 8a36b2af76522..33bd98d467cec 100644 --- a/library/alloc/src/collections/binary_heap.rs +++ b/library/alloc/src/collections/binary_heap.rs @@ -275,7 +275,8 @@ impl fmt::Debug for PeekMut<'_, T> { impl Drop for PeekMut<'_, T> { fn drop(&mut self) { if self.sift { - self.heap.sift_down(0); + // SAFETY: PeekMut is only instantiated for non-empty heaps. + unsafe { self.heap.sift_down(0) }; } } } @@ -431,7 +432,8 @@ impl BinaryHeap { self.data.pop().map(|mut item| { if !self.is_empty() { swap(&mut item, &mut self.data[0]); - self.sift_down_to_bottom(0); + // SAFETY: !self.is_empty() means that self.len() > 0 + unsafe { self.sift_down_to_bottom(0) }; } item }) @@ -473,7 +475,9 @@ impl BinaryHeap { pub fn push(&mut self, item: T) { let old_len = self.len(); self.data.push(item); - self.sift_up(0, old_len); + // SAFETY: Since we pushed a new item it means that + // old_len = self.len() - 1 < self.len() + unsafe { self.sift_up(0, old_len) }; } /// Consumes the `BinaryHeap` and returns a vector in sorted @@ -506,7 +510,10 @@ impl BinaryHeap { let ptr = self.data.as_mut_ptr(); ptr::swap(ptr, ptr.add(end)); } - self.sift_down_range(0, end); + // SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so: + // 0 < 1 <= end <= self.len() - 1 < self.len() + // Which means 0 < end and end < self.len(). + unsafe { self.sift_down_range(0, end) }; } self.into_vec() } @@ -519,47 +526,84 @@ impl BinaryHeap { // the hole is filled back at the end of its scope, even on panic. // Using a hole reduces the constant factor compared to using swaps, // which involves twice as many moves. - fn sift_up(&mut self, start: usize, pos: usize) -> usize { - unsafe { - // Take out the value at `pos` and create a hole. - let mut hole = Hole::new(&mut self.data, pos); - - while hole.pos() > start { - let parent = (hole.pos() - 1) / 2; - if hole.element() <= hole.get(parent) { - break; - } - hole.move_to(parent); + + /// # Safety + /// + /// The caller must guarantee that `pos < self.len()`. + unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize { + // Take out the value at `pos` and create a hole. + // SAFETY: The caller guarantees that pos < self.len() + let mut hole = unsafe { Hole::new(&mut self.data, pos) }; + + while hole.pos() > start { + let parent = (hole.pos() - 1) / 2; + + // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0 + // and so hole.pos() - 1 can't underflow. + // This guarantees that parent < hole.pos() so + // it's a valid index and also != hole.pos(). + if hole.element() <= unsafe { hole.get(parent) } { + break; } - hole.pos() + + // SAFETY: Same as above + unsafe { hole.move_to(parent) }; } + + hole.pos() } /// Take an element at `pos` and move it down the heap, /// while its children are larger. - fn sift_down_range(&mut self, pos: usize, end: usize) { - unsafe { - let mut hole = Hole::new(&mut self.data, pos); - let mut child = 2 * pos + 1; - while child < end - 1 { - // compare with the greater of the two children - child += (hole.get(child) <= hole.get(child + 1)) as usize; - // if we are already in order, stop. - if hole.element() >= hole.get(child) { - return; - } - hole.move_to(child); - child = 2 * hole.pos() + 1; - } - if child == end - 1 && hole.element() < hole.get(child) { - hole.move_to(child); + /// + /// # Safety + /// + /// The caller must guarantee that `pos < end <= self.len()`. + unsafe fn sift_down_range(&mut self, pos: usize, end: usize) { + // SAFETY: The caller guarantees that pos < end <= self.len(). + let mut hole = unsafe { Hole::new(&mut self.data, pos) }; + let mut child = 2 * hole.pos() + 1; + + // Loop invariant: child == 2 * hole.pos() + 1. + while child < end - 1 { + // compare with the greater of the two children + // SAFETY: child < end - 1 < self.len() and + // child + 1 < end <= self.len(), so they're valid indexes. + // child == 2 * hole.pos() + 1 != hole.pos() and + // child + 1 == 2 * hole.pos() + 2 != hole.pos(). + // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow + // if T is a ZST + child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize; + + // if we are already in order, stop. + // SAFETY: child is now either the old child or the old child+1 + // We already proven that both are < self.len() and != hole.pos() + if hole.element() >= unsafe { hole.get(child) } { + return; } + + // SAFETY: same as above. + unsafe { hole.move_to(child) }; + child = 2 * hole.pos() + 1; + } + + // SAFETY: && short circuit, which means that in the + // second condition it's already true that child == end - 1 < self.len(). + if child == end - 1 && hole.element() < unsafe { hole.get(child) } { + // SAFETY: child is already proven to be a valid index and + // child == 2 * hole.pos() + 1 != hole.pos(). + unsafe { hole.move_to(child) }; } } - fn sift_down(&mut self, pos: usize) { + /// # Safety + /// + /// The caller must guarantee that `pos < self.len()`. + unsafe fn sift_down(&mut self, pos: usize) { let len = self.len(); - self.sift_down_range(pos, len); + // SAFETY: pos < len is guaranteed by the caller and + // obviously len = self.len() <= self.len(). + unsafe { self.sift_down_range(pos, len) }; } /// Take an element at `pos` and move it all the way down the heap, @@ -567,30 +611,54 @@ impl BinaryHeap { /// /// Note: This is faster when the element is known to be large / should /// be closer to the bottom. - fn sift_down_to_bottom(&mut self, mut pos: usize) { + /// + /// # Safety + /// + /// The caller must guarantee that `pos < self.len()`. + unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) { let end = self.len(); let start = pos; - unsafe { - let mut hole = Hole::new(&mut self.data, pos); - let mut child = 2 * pos + 1; - while child < end - 1 { - child += (hole.get(child) <= hole.get(child + 1)) as usize; - hole.move_to(child); - child = 2 * hole.pos() + 1; - } - if child == end - 1 { - hole.move_to(child); - } - pos = hole.pos; + + // SAFETY: The caller guarantees that pos < self.len(). + let mut hole = unsafe { Hole::new(&mut self.data, pos) }; + let mut child = 2 * hole.pos() + 1; + + // Loop invariant: child == 2 * hole.pos() + 1. + while child < end - 1 { + // SAFETY: child < end - 1 < self.len() and + // child + 1 < end <= self.len(), so they're valid indexes. + // child == 2 * hole.pos() + 1 != hole.pos() and + // child + 1 == 2 * hole.pos() + 2 != hole.pos(). + // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow + // if T is a ZST + child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize; + + // SAFETY: Same as above + unsafe { hole.move_to(child) }; + child = 2 * hole.pos() + 1; } - self.sift_up(start, pos); + + if child == end - 1 { + // SAFETY: child == end - 1 < self.len(), so it's a valid index + // and child == 2 * hole.pos() + 1 != hole.pos(). + unsafe { hole.move_to(child) }; + } + pos = hole.pos(); + drop(hole); + + // SAFETY: pos is the position in the hole and was already proven + // to be a valid index. + unsafe { self.sift_up(start, pos) }; } fn rebuild(&mut self) { let mut n = self.len() / 2; while n > 0 { n -= 1; - self.sift_down(n); + // SAFETY: n starts from self.len() / 2 and goes down to 0. + // The only case when !(n < self.len()) is if + // self.len() == 0, but it's ruled out by the loop condition. + unsafe { self.sift_down(n) }; } }