Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'core::array::from_fn' and 'core::array::try_from_fn' #75644

Merged
merged 7 commits into from
Oct 9, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 100 additions & 17 deletions library/core/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,69 @@ mod iter;
#[stable(feature = "array_value_iter", since = "1.51.0")]
pub use iter::IntoIter;

/// Creates an array `[T; N]` where each array element `T` is returned by the `cb` call.
///
/// # Arguments
///
/// * `cb`: Callback where the passed argument is the current array index.
///
/// # Example
///
/// ```rust
/// #![feature(array_from_fn)]
///
/// let array = core::array::from_fn(|i| i);
/// assert_eq!(array, [0, 1, 2, 3, 4]);
/// ```
#[inline]
#[unstable(feature = "array_from_fn", issue = "89379")]
pub fn from_fn<F, T, const N: usize>(mut cb: F) -> [T; N]
where
F: FnMut(usize) -> T,
{
let mut idx = 0;
[(); N].map(|_| {
let res = cb(idx);
idx += 1;
res
})
}

/// Creates an array `[T; N]` where each fallible array element `T` is returned by the `cb` call.
/// Unlike `core::array::from_fn`, where the element creation can't fail, this version will return an error
/// if any element creation was unsuccessful.
///
/// # Arguments
///
/// * `cb`: Callback where the passed argument is the current array index.
///
/// # Example
///
/// ```rust
/// #![feature(array_from_fn)]
///
/// #[derive(Debug, PartialEq)]
/// enum SomeError {
/// Foo,
/// }
///
/// let array = core::array::try_from_fn(|i| Ok::<_, SomeError>(i));
/// assert_eq!(array, Ok([0, 1, 2, 3, 4]));
///
/// let another_array = core::array::try_from_fn::<SomeError, _, (), 2>(|_| Err(SomeError::Foo));
/// assert_eq!(another_array, Err(SomeError::Foo));
/// ```
#[inline]
#[unstable(feature = "array_from_fn", issue = "89379")]
pub fn try_from_fn<E, F, T, const N: usize>(cb: F) -> Result<[T; N], E>
where
F: FnMut(usize) -> Result<T, E>,
{
// SAFETY: we know for certain that this iterator will yield exactly `N`
// items.
unsafe { collect_into_array_rslt_unchecked(&mut (0..N).map(cb)) }
}

/// Converts a reference to `T` into a reference to an array of length 1 (without copying).
#[stable(feature = "array_from_ref", since = "1.53.0")]
pub fn from_ref<T>(s: &T) -> &[T; 1] {
Expand Down Expand Up @@ -448,13 +511,15 @@ impl<T, const N: usize> [T; N] {
///
/// It is up to the caller to guarantee that `iter` yields at least `N` items.
/// Violating this condition causes undefined behavior.
unsafe fn collect_into_array_unchecked<I, const N: usize>(iter: &mut I) -> [I::Item; N]
unsafe fn collect_into_array_rslt_unchecked<E, I, T, const N: usize>(
iter: &mut I,
) -> Result<[T; N], E>
where
// Note: `TrustedLen` here is somewhat of an experiment. This is just an
// internal function, so feel free to remove if this bound turns out to be a
// bad idea. In that case, remember to also remove the lower bound
// `debug_assert!` below!
I: Iterator + TrustedLen,
I: Iterator<Item = Result<T, E>> + TrustedLen,
{
debug_assert!(N <= iter.size_hint().1.unwrap_or(usize::MAX));
debug_assert!(N <= iter.size_hint().0);
Expand All @@ -463,6 +528,18 @@ where
unsafe { collect_into_array(iter).unwrap_unchecked() }
}

// Infallible version of `collect_into_array_rslt_unchecked`.
unsafe fn collect_into_array_unchecked<I, const N: usize>(iter: &mut I) -> [I::Item; N]
where
I: Iterator + TrustedLen,
{
let mut map = iter.map(|el| Ok::<_, Infallible>(el));

// SAFETY: Valid array elements are covered by the fact that all passed values
// to `collect_into_array` are `Ok`.
unsafe { collect_into_array_rslt_unchecked(&mut map).unwrap_unchecked() }
c410-f3r marked this conversation as resolved.
Show resolved Hide resolved
}
c410-f3r marked this conversation as resolved.
Show resolved Hide resolved

/// Pulls `N` items from `iter` and returns them as an array. If the iterator
/// yields fewer than `N` items, `None` is returned and all already yielded
/// items are dropped.
Expand All @@ -473,43 +550,49 @@ where
///
/// If `iter.next()` panicks, all items already yielded by the iterator are
/// dropped.
fn collect_into_array<I, const N: usize>(iter: &mut I) -> Option<[I::Item; N]>
fn collect_into_array<E, I, T, const N: usize>(iter: &mut I) -> Option<Result<[T; N], E>>
c410-f3r marked this conversation as resolved.
Show resolved Hide resolved
where
I: Iterator,
I: Iterator<Item = Result<T, E>>,
{
if N == 0 {
// SAFETY: An empty array is always inhabited and has no validity invariants.
return unsafe { Some(mem::zeroed()) };
return unsafe { Some(Ok(mem::zeroed())) };
}

struct Guard<T, const N: usize> {
ptr: *mut T,
struct Guard<'a, T, const N: usize> {
array_mut: &'a mut [MaybeUninit<T>; N],
initialized: usize,
}

impl<T, const N: usize> Drop for Guard<T, N> {
impl<T, const N: usize> Drop for Guard<'_, T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);

let initialized_part = crate::ptr::slice_from_raw_parts_mut(self.ptr, self.initialized);

// SAFETY: this raw slice will contain only initialized objects.
// SAFETY: this slice will contain only initialized objects.
unsafe {
crate::ptr::drop_in_place(initialized_part);
crate::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
&mut self.array_mut.get_unchecked_mut(..self.initialized),
));
}
}
}

let mut array = MaybeUninit::uninit_array::<N>();
let mut guard: Guard<_, N> =
Guard { ptr: MaybeUninit::slice_as_mut_ptr(&mut array), initialized: 0 };
let mut guard = Guard { array_mut: &mut array, initialized: 0 };

while let Some(item_rslt) = iter.next() {
let item = match item_rslt {
Err(err) => {
return Some(Err(err));
}
Ok(elem) => elem,
};

while let Some(item) = iter.next() {
// SAFETY: `guard.initialized` starts at 0, is increased by one in the
// loop and the loop is aborted once it reaches N (which is
// `array.len()`).
unsafe {
array.get_unchecked_mut(guard.initialized).write(item);
guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
}
guard.initialized += 1;

Expand All @@ -520,7 +603,7 @@ where
// SAFETY: the condition above asserts that all elements are
// initialized.
let out = unsafe { MaybeUninit::array_assume_init(array) };
return Some(out);
return Some(Ok(out));
}
}

Expand Down
81 changes: 79 additions & 2 deletions library/core/tests/array.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::array;
use core::convert::TryFrom;
use core::sync::atomic::{AtomicUsize, Ordering};

#[test]
fn array_from_ref() {
Expand Down Expand Up @@ -303,8 +304,6 @@ fn array_map() {
#[test]
#[should_panic(expected = "test succeeded")]
fn array_map_drop_safety() {
use core::sync::atomic::AtomicUsize;
use core::sync::atomic::Ordering;
static DROPPED: AtomicUsize = AtomicUsize::new(0);
struct DropCounter;
impl Drop for DropCounter {
Expand Down Expand Up @@ -356,3 +355,81 @@ fn cell_allows_array_cycle() {
b3.a[0].set(Some(&b1));
b3.a[1].set(Some(&b2));
}

#[test]
fn array_from_fn() {
let array = core::array::from_fn(|idx| idx);
assert_eq!(array, [0, 1, 2, 3, 4]);
}

#[test]
fn array_try_from_fn() {
#[derive(Debug, PartialEq)]
enum SomeError {
Foo,
}

let array = core::array::try_from_fn(|i| Ok::<_, SomeError>(i));
assert_eq!(array, Ok([0, 1, 2, 3, 4]));

let another_array = core::array::try_from_fn::<SomeError, _, (), 2>(|_| Err(SomeError::Foo));
assert_eq!(another_array, Err(SomeError::Foo));
}

#[test]
fn array_try_from_fn_drops_inserted_elements_on_err() {
static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

struct CountDrop;
impl Drop for CountDrop {
fn drop(&mut self) {
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
}
}

let _ = catch_unwind_silent(move || {
let _: Result<[CountDrop; 4], ()> = core::array::try_from_fn(|idx| {
if idx == 2 {
return Err(());
}
Ok(CountDrop)
});
});

assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
}

#[test]
fn array_try_from_fn_drops_inserted_elements_on_panic() {
c410-f3r marked this conversation as resolved.
Show resolved Hide resolved
static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

struct CountDrop;
impl Drop for CountDrop {
fn drop(&mut self) {
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
}
}

let _ = catch_unwind_silent(move || {
let _: Result<[CountDrop; 4], ()> = core::array::try_from_fn(|idx| {
if idx == 2 {
panic!("peek a boo");
}
Ok(CountDrop)
});
});

assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
}

// https://stackoverflow.com/a/59211505
fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
where
F: FnOnce() -> R + core::panic::UnwindSafe,
{
let prev_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(|_| {}));
let result = std::panic::catch_unwind(f);
std::panic::set_hook(prev_hook);
result
}
1 change: 1 addition & 0 deletions library/core/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#![feature(extern_types)]
#![feature(flt2dec)]
#![feature(fmt_internals)]
#![feature(array_from_fn)]
#![feature(hashmap_internals)]
#![feature(try_find)]
#![feature(is_sorted)]
Expand Down