From d2734352bfe921aec79ee6ccec6e13ab7bcc5503 Mon Sep 17 00:00:00 2001 From: Kezhu Wang Date: Tue, 21 May 2024 11:05:39 +0800 Subject: [PATCH] Drop task future before fulfill join handle (#4) This way after `task::spawn(future).await`, we know that all values contained in the `future` are dropped. Resolves #3. --- spawns-core/src/task.rs | 88 ++++++++++++++++++++++++++++++----------- 1 file changed, 64 insertions(+), 24 deletions(-) diff --git a/spawns-core/src/task.rs b/spawns-core/src/task.rs index 2485728..dc7f16f 100644 --- a/spawns-core/src/task.rs +++ b/spawns-core/src/task.rs @@ -278,28 +278,33 @@ impl Future for IdFuture { } struct TaskFuture { - ready: bool, waker: Option>, cancellation: Arc, joint: Arc>, - future: F, + future: Option, } impl TaskFuture { fn new(future: F) -> Self { Self { - ready: false, waker: None, joint: Arc::new(Joint::new()), cancellation: Arc::new(Default::default()), - future, + future: Some(future), } } + + fn finish(&mut self, value: Result) -> Poll<()> { + self.future = None; + self.joint.wake(value); + Poll::Ready(()) + } } + impl Drop for TaskFuture { fn drop(&mut self) { - if !self.ready { - self.joint.wake(Err(InnerJoinError::Cancelled)); + if self.future.is_some() { + let _ = self.finish(Err(InnerJoinError::Cancelled)); } } } @@ -309,14 +314,12 @@ impl Future for TaskFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let task = unsafe { self.get_unchecked_mut() }; - if task.ready { + if task.future.is_none() { return Poll::Ready(()); } else if task.cancellation.is_cancelled() { - task.joint.wake(Err(InnerJoinError::Cancelled)); - task.ready = true; - return Poll::Ready(()); + return task.finish(Err(InnerJoinError::Cancelled)); } - let future = unsafe { Pin::new_unchecked(&mut task.future) }; + let future = unsafe { Pin::new_unchecked(task.future.as_mut().unwrap_unchecked()) }; match panic::catch_unwind(AssertUnwindSafe(|| future.poll(cx))) { Ok(Poll::Pending) => { let waker = match task.waker.take() { @@ -327,23 +330,13 @@ impl Future for TaskFuture { } }; let Ok(waker) = task.cancellation.update(waker) else { - task.joint.wake(Err(InnerJoinError::Cancelled)); - task.ready = true; - return Poll::Ready(()); + return task.finish(Err(InnerJoinError::Cancelled)); }; task.waker = waker; Poll::Pending } - Ok(Poll::Ready(value)) => { - task.joint.wake(Ok(value)); - task.ready = true; - Poll::Ready(()) - } - Err(err) => { - task.joint.wake(Err(InnerJoinError::Panic(err))); - task.ready = true; - Poll::Ready(()) - } + Ok(Poll::Ready(value)) => task.finish(Ok(value)), + Err(err) => task.finish(Err(InnerJoinError::Panic(err))), } } } @@ -999,4 +992,51 @@ mod tests { block_on(Box::into_pin(task.future)); assert_eq!(cancelled.load(Ordering::Relaxed), false); } + + struct CustomFuture { + _shared: Arc<()>, + } + + impl Future for CustomFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + Poll::Ready(()) + } + } + + #[test] + fn future_dropped_before_ready() { + let shared = Arc::new(()); + let (mut task, _handle) = Task::new( + Name::default(), + CustomFuture { + _shared: shared.clone(), + }, + ); + let pinned = unsafe { Pin::new_unchecked(task.future.as_mut()) }; + let poll = pinned.poll(&mut Context::from_waker(futures::task::noop_waker_ref())); + assert!(poll.is_ready()); + assert_eq!(Arc::strong_count(&shared), 1); + } + + #[test] + fn future_dropped_before_joined() { + let shared = Arc::new(()); + let (mut task, handle) = Task::new( + Name::default(), + CustomFuture { + _shared: shared.clone(), + }, + ); + std::thread::spawn(move || { + let pinned = unsafe { Pin::new_unchecked(task.future.as_mut()) }; + let _poll = pinned.poll(&mut Context::from_waker(futures::task::noop_waker_ref())); + + // Let join handle complete before task drop. + std::thread::sleep(Duration::from_millis(10)); + }); + block_on(handle).unwrap(); + assert_eq!(Arc::strong_count(&shared), 1); + } }