diff --git a/maitake/src/task.rs b/maitake/src/task.rs index 555ba02a..35b00c2c 100644 --- a/maitake/src/task.rs +++ b/maitake/src/task.rs @@ -734,21 +734,7 @@ where #[cfg(any(feature = "tracing-01", feature = "tracing-02", test))] let _span = self.span().enter(); - self.inner.with_mut(|cell| { - let cell = unsafe { &mut *cell }; - let poll = match cell { - Cell::Pending(future) => unsafe { Pin::new_unchecked(future).poll(&mut cx) }, - _ => unreachable!("tried to poll a completed future!"), - }; - - match poll { - Poll::Ready(ready) => { - *cell = Cell::Ready(ready); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - } - }) + self.inner.with_mut(|cell| unsafe { (*cell).poll(&mut cx) }) } /// Wakes the task's [`JoinHandle`], if it has one. @@ -1427,6 +1413,23 @@ impl fmt::Debug for Cell { } } +impl Cell { + fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> { + let poll = match self { + Cell::Pending(future) => unsafe { Pin::new_unchecked(future).poll(cx) }, + _ => unreachable!("tried to poll a completed future!"), + }; + + match poll { + Poll::Ready(ready) => { + *self = Cell::Ready(ready); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + } + } +} + // === impl Vtable === impl fmt::Debug for Vtable { diff --git a/maitake/src/task/tests/alloc_tests.rs b/maitake/src/task/tests/alloc_tests.rs index caa2e8ef..0b1c0686 100644 --- a/maitake/src/task/tests/alloc_tests.rs +++ b/maitake/src/task/tests/alloc_tests.rs @@ -264,3 +264,53 @@ fn drop_join_handle() { assert!(COMPLETED.load(Ordering::Relaxed)) } + +// Test for potential UB in `Cell::poll` due to niche optimization. +// See https://github.com/rust-lang/miri/issues/3780 for details. +// +// This is based on the test for analogous types in Tokio added in +// https://github.com/tokio-rs/tokio/pull/6744 +#[test] +fn cell_miri() { + use super::Cell; + use alloc::{string::String, sync::Arc, task::Wake}; + use core::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; + + struct DummyWaker; + + impl Wake for DummyWaker { + fn wake(self: Arc) {} + } + + struct ThingAdder<'a> { + thing: &'a mut String, + } + + impl Future for ThingAdder<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + unsafe { + *self.get_unchecked_mut().thing += ", world"; + } + Poll::Pending + } + } + + let mut thing = "hello".to_owned(); + + // The async block is necessary to trigger the miri failure. + #[allow(clippy::redundant_async_block)] + let fut = async move { ThingAdder { thing: &mut thing }.await }; + + let mut fut = Cell::Pending(fut); + + let waker = Arc::new(DummyWaker).into(); + let mut ctx = Context::from_waker(&waker); + assert_eq!(fut.poll(&mut ctx), Poll::Pending); + assert_eq!(fut.poll(&mut ctx), Poll::Pending); +}