@@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
275
275
impl < T : Ord > Drop for PeekMut < ' _ , T > {
276
276
fn drop ( & mut self ) {
277
277
if self . sift {
278
- self . heap . sift_down ( 0 ) ;
278
+ // SAFETY: PeekMut is only instantiated for non-empty heaps.
279
+ unsafe { self . heap . sift_down ( 0 ) } ;
279
280
}
280
281
}
281
282
}
@@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
431
432
self . data . pop ( ) . map ( |mut item| {
432
433
if !self . is_empty ( ) {
433
434
swap ( & mut item, & mut self . data [ 0 ] ) ;
434
- self . sift_down_to_bottom ( 0 ) ;
435
+ // SAFETY: !self.is_empty() means that self.len() > 0
436
+ unsafe { self . sift_down_to_bottom ( 0 ) } ;
435
437
}
436
438
item
437
439
} )
@@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
473
475
pub fn push ( & mut self , item : T ) {
474
476
let old_len = self . len ( ) ;
475
477
self . data . push ( item) ;
476
- self . sift_up ( 0 , old_len) ;
478
+ // SAFETY: Since we pushed a new item it means that
479
+ // old_len = self.len() - 1 < self.len()
480
+ unsafe { self . sift_up ( 0 , old_len) } ;
477
481
}
478
482
479
483
/// Consumes the `BinaryHeap` and returns a vector in sorted
@@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
506
510
let ptr = self . data . as_mut_ptr ( ) ;
507
511
ptr:: swap ( ptr, ptr. add ( end) ) ;
508
512
}
509
- self . sift_down_range ( 0 , end) ;
513
+ // SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
514
+ // 0 < 1 <= end <= self.len() - 1 < self.len()
515
+ // Which means 0 < end and end < self.len().
516
+ unsafe { self . sift_down_range ( 0 , end) } ;
510
517
}
511
518
self . into_vec ( )
512
519
}
@@ -519,78 +526,139 @@ impl<T: Ord> BinaryHeap<T> {
519
526
// the hole is filled back at the end of its scope, even on panic.
520
527
// Using a hole reduces the constant factor compared to using swaps,
521
528
// which involves twice as many moves.
522
- fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
523
- unsafe {
524
- // Take out the value at `pos` and create a hole.
525
- let mut hole = Hole :: new ( & mut self . data , pos) ;
526
-
527
- while hole. pos ( ) > start {
528
- let parent = ( hole. pos ( ) - 1 ) / 2 ;
529
- if hole. element ( ) <= hole. get ( parent) {
530
- break ;
531
- }
532
- hole. move_to ( parent) ;
529
+
530
+ /// # Safety
531
+ ///
532
+ /// The caller must guarantee that `pos < self.len()`.
533
+ unsafe fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
534
+ // Take out the value at `pos` and create a hole.
535
+ // SAFETY: The caller guarantees that pos < self.len()
536
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
537
+
538
+ while hole. pos ( ) > start {
539
+ let parent = ( hole. pos ( ) - 1 ) / 2 ;
540
+
541
+ // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
542
+ // and so hole.pos() - 1 can't underflow.
543
+ // This guarantees that parent < hole.pos() so
544
+ // it's a valid index and also != hole.pos().
545
+ if hole. element ( ) <= unsafe { hole. get ( parent) } {
546
+ break ;
533
547
}
534
- hole. pos ( )
548
+
549
+ // SAFETY: Same as above
550
+ unsafe { hole. move_to ( parent) } ;
535
551
}
552
+
553
+ hole. pos ( )
536
554
}
537
555
538
556
/// Take an element at `pos` and move it down the heap,
539
557
/// while its children are larger.
540
- fn sift_down_range ( & mut self , pos : usize , end : usize ) {
541
- unsafe {
542
- let mut hole = Hole :: new ( & mut self . data , pos) ;
543
- let mut child = 2 * pos + 1 ;
544
- while child < end - 1 {
545
- // compare with the greater of the two children
546
- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
547
- // if we are already in order, stop.
548
- if hole. element ( ) >= hole. get ( child) {
549
- return ;
550
- }
551
- hole. move_to ( child) ;
552
- child = 2 * hole. pos ( ) + 1 ;
553
- }
554
- if child == end - 1 && hole. element ( ) < hole. get ( child) {
555
- hole. move_to ( child) ;
558
+ ///
559
+ /// # Safety
560
+ ///
561
+ /// The caller must guarantee that `pos < end <= self.len()`.
562
+ unsafe fn sift_down_range ( & mut self , pos : usize , end : usize ) {
563
+ // SAFETY: The caller guarantees that pos < end <= self.len().
564
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
565
+ let mut child = 2 * hole. pos ( ) + 1 ;
566
+
567
+ // Loop invariant: child == 2 * hole.pos() + 1.
568
+ while child < end - 1 {
569
+ // compare with the greater of the two children
570
+ // SAFETY: child < end - 1 < self.len() and
571
+ // child + 1 < end <= self.len(), so they're valid indexes.
572
+ // child == 2 * hole.pos() + 1 != hole.pos() and
573
+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
574
+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
575
+ // if T is a ZST
576
+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
577
+
578
+ // if we are already in order, stop.
579
+ // SAFETY: child is now either the old child or the old child+1
580
+ // We already proven that both are < self.len() and != hole.pos()
581
+ if hole. element ( ) >= unsafe { hole. get ( child) } {
582
+ return ;
556
583
}
584
+
585
+ // SAFETY: same as above.
586
+ unsafe { hole. move_to ( child) } ;
587
+ child = 2 * hole. pos ( ) + 1 ;
588
+ }
589
+
590
+ // SAFETY: && short circuit, which means that in the
591
+ // second condition it's already true that child == end - 1 < self.len().
592
+ if child == end - 1 && hole. element ( ) < unsafe { hole. get ( child) } {
593
+ // SAFETY: child is already proven to be a valid index and
594
+ // child == 2 * hole.pos() + 1 != hole.pos().
595
+ unsafe { hole. move_to ( child) } ;
557
596
}
558
597
}
559
598
560
- fn sift_down ( & mut self , pos : usize ) {
599
+ /// # Safety
600
+ ///
601
+ /// The caller must guarantee that `pos < self.len()`.
602
+ unsafe fn sift_down ( & mut self , pos : usize ) {
561
603
let len = self . len ( ) ;
562
- self . sift_down_range ( pos, len) ;
604
+ // SAFETY: pos < len is guaranteed by the caller and
605
+ // obviously len = self.len() <= self.len().
606
+ unsafe { self . sift_down_range ( pos, len) } ;
563
607
}
564
608
565
609
/// Take an element at `pos` and move it all the way down the heap,
566
610
/// then sift it up to its position.
567
611
///
568
612
/// Note: This is faster when the element is known to be large / should
569
613
/// be closer to the bottom.
570
- fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
614
+ ///
615
+ /// # Safety
616
+ ///
617
+ /// The caller must guarantee that `pos < self.len()`.
618
+ unsafe fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
571
619
let end = self . len ( ) ;
572
620
let start = pos;
573
- unsafe {
574
- let mut hole = Hole :: new ( & mut self . data , pos) ;
575
- let mut child = 2 * pos + 1 ;
576
- while child < end - 1 {
577
- child += ( hole. get ( child) <= hole. get ( child + 1 ) ) as usize ;
578
- hole. move_to ( child) ;
579
- child = 2 * hole. pos ( ) + 1 ;
580
- }
581
- if child == end - 1 {
582
- hole. move_to ( child) ;
583
- }
584
- pos = hole. pos ;
621
+
622
+ // SAFETY: The caller guarantees that pos < self.len().
623
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
624
+ let mut child = 2 * hole. pos ( ) + 1 ;
625
+
626
+ // Loop invariant: child == 2 * hole.pos() + 1.
627
+ while child < end - 1 {
628
+ // SAFETY: child < end - 1 < self.len() and
629
+ // child + 1 < end <= self.len(), so they're valid indexes.
630
+ // child == 2 * hole.pos() + 1 != hole.pos() and
631
+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
632
+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
633
+ // if T is a ZST
634
+ child += unsafe { hole. get ( child) <= hole. get ( child + 1 ) } as usize ;
635
+
636
+ // SAFETY: Same as above
637
+ unsafe { hole. move_to ( child) } ;
638
+ child = 2 * hole. pos ( ) + 1 ;
585
639
}
586
- self . sift_up ( start, pos) ;
640
+
641
+ if child == end - 1 {
642
+ // SAFETY: child == end - 1 < self.len(), so it's a valid index
643
+ // and child == 2 * hole.pos() + 1 != hole.pos().
644
+ unsafe { hole. move_to ( child) } ;
645
+ }
646
+ pos = hole. pos ( ) ;
647
+ drop ( hole) ;
648
+
649
+ // SAFETY: pos is the position in the hole and was already proven
650
+ // to be a valid index.
651
+ unsafe { self . sift_up ( start, pos) } ;
587
652
}
588
653
589
654
fn rebuild ( & mut self ) {
590
655
let mut n = self . len ( ) / 2 ;
591
656
while n > 0 {
592
657
n -= 1 ;
593
- self . sift_down ( n) ;
658
+ // SAFETY: n starts from self.len() / 2 and goes down to 0.
659
+ // The only case when !(n < self.len()) is if
660
+ // self.len() == 0, but it's ruled out by the loop condition.
661
+ unsafe { self . sift_down ( n) } ;
594
662
}
595
663
}
596
664
0 commit comments