diff --git a/core/src/iter/adapters/filter.rs b/core/src/iter/adapters/filter.rs index ca23d1b13a88b..ba49070329c22 100644 --- a/core/src/iter/adapters/filter.rs +++ b/core/src/iter/adapters/filter.rs @@ -3,7 +3,7 @@ use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedF use crate::num::NonZero; use crate::ops::Try; use core::array; -use core::mem::{ManuallyDrop, MaybeUninit}; +use core::mem::MaybeUninit; use core::ops::ControlFlow; /// An iterator that filters the elements of `iter` with `predicate`. @@ -27,6 +27,42 @@ impl Filter { } } +impl Filter +where + I: Iterator, + P: FnMut(&I::Item) -> bool, +{ + #[inline] + fn next_chunk_dropless( + &mut self, + ) -> Result<[I::Item; N], array::IntoIter> { + let mut array: [MaybeUninit; N] = [const { MaybeUninit::uninit() }; N]; + let mut initialized = 0; + + let result = self.iter.try_for_each(|element| { + let idx = initialized; + // branchless index update combined with unconditionally copying the value even when + // it is filtered reduces branching and dependencies in the loop. + initialized = idx + (self.predicate)(&element) as usize; + // SAFETY: Loop conditions ensure the index is in bounds. + unsafe { array.get_unchecked_mut(idx) }.write(element); + + if initialized < N { ControlFlow::Continue(()) } else { ControlFlow::Break(()) } + }); + + match result { + ControlFlow::Break(()) => { + // SAFETY: The loop above is only explicitly broken when the array has been fully initialized + Ok(unsafe { MaybeUninit::array_assume_init(array) }) + } + ControlFlow::Continue(()) => { + // SAFETY: The range is in bounds since the loop breaks when reaching N elements. + Err(unsafe { array::IntoIter::new_unchecked(array, 0..initialized) }) + } + } + } +} + #[stable(feature = "core_impl_debug", since = "1.9.0")] impl fmt::Debug for Filter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -64,52 +100,16 @@ where fn next_chunk( &mut self, ) -> Result<[Self::Item; N], array::IntoIter> { - let mut array: [MaybeUninit; N] = [const { MaybeUninit::uninit() }; N]; - - struct Guard<'a, T> { - array: &'a mut [MaybeUninit], - initialized: usize, - } - - impl Drop for Guard<'_, T> { - #[inline] - fn drop(&mut self) { - if const { crate::mem::needs_drop::() } { - // SAFETY: self.initialized is always <= N, which also is the length of the array. - unsafe { - core::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut( - self.array.get_unchecked_mut(..self.initialized), - )); - } - } + // avoid codegen for the dead branch + let fun = const { + if crate::mem::needs_drop::() { + array::iter_next_chunk:: + } else { + Self::next_chunk_dropless:: } - } - - let mut guard = Guard { array: &mut array, initialized: 0 }; - - let result = self.iter.try_for_each(|element| { - let idx = guard.initialized; - guard.initialized = idx + (self.predicate)(&element) as usize; - - // SAFETY: Loop conditions ensure the index is in bounds. - unsafe { guard.array.get_unchecked_mut(idx) }.write(element); - - if guard.initialized < N { ControlFlow::Continue(()) } else { ControlFlow::Break(()) } - }); + }; - let guard = ManuallyDrop::new(guard); - - match result { - ControlFlow::Break(()) => { - // SAFETY: The loop above is only explicitly broken when the array has been fully initialized - Ok(unsafe { MaybeUninit::array_assume_init(array) }) - } - ControlFlow::Continue(()) => { - let initialized = guard.initialized; - // SAFETY: The range is in bounds since the loop breaks when reaching N elements. - Err(unsafe { array::IntoIter::new_unchecked(array, 0..initialized) }) - } - } + fun(self) } #[inline] diff --git a/core/tests/iter/adapters/filter.rs b/core/tests/iter/adapters/filter.rs index a2050d89d8564..167851e33336e 100644 --- a/core/tests/iter/adapters/filter.rs +++ b/core/tests/iter/adapters/filter.rs @@ -1,4 +1,5 @@ use core::iter::*; +use std::rc::Rc; #[test] fn test_iterator_filter_count() { @@ -50,3 +51,15 @@ fn test_double_ended_filter() { assert_eq!(it.next().unwrap(), &2); assert_eq!(it.next_back(), None); } + +#[test] +fn test_next_chunk_does_not_leak() { + let drop_witness: [_; 5] = std::array::from_fn(|_| Rc::new(())); + + let v = (0..5).map(|i| drop_witness[i].clone()).collect::>(); + let _ = v.into_iter().filter(|_| false).next_chunk::<1>(); + + for ref w in drop_witness { + assert_eq!(Rc::strong_count(w), 1); + } +}