Skip to content

Commit

Permalink
Drop task future before fulfill join handle (#4)
Browse files Browse the repository at this point in the history
This way after `task::spawn(future).await`, we know that all
values contained in the `future` are dropped.

Resolves #3.
  • Loading branch information
kezhuw authored May 21, 2024
1 parent 6d75a28 commit d273435
Showing 1 changed file with 64 additions and 24 deletions.
88 changes: 64 additions & 24 deletions spawns-core/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,28 +278,33 @@ impl<F: Future> Future for IdFuture<F> {
}

struct TaskFuture<F: Future> {
ready: bool,
waker: Option<Box<Waker>>,
cancellation: Arc<Cancellation>,
joint: Arc<Joint<F::Output>>,
future: F,
future: Option<F>,
}

impl<F: Future> TaskFuture<F> {
fn new(future: F) -> Self {
Self {
ready: false,
waker: None,
joint: Arc::new(Joint::new()),
cancellation: Arc::new(Default::default()),
future,
future: Some(future),
}
}

fn finish(&mut self, value: Result<F::Output, InnerJoinError>) -> Poll<()> {
self.future = None;
self.joint.wake(value);
Poll::Ready(())
}
}

impl<F: Future> Drop for TaskFuture<F> {
fn drop(&mut self) {
if !self.ready {
self.joint.wake(Err(InnerJoinError::Cancelled));
if self.future.is_some() {
let _ = self.finish(Err(InnerJoinError::Cancelled));
}
}
}
Expand All @@ -309,14 +314,12 @@ impl<F: Future> Future for TaskFuture<F> {

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let task = unsafe { self.get_unchecked_mut() };
if task.ready {
if task.future.is_none() {
return Poll::Ready(());
} else if task.cancellation.is_cancelled() {
task.joint.wake(Err(InnerJoinError::Cancelled));
task.ready = true;
return Poll::Ready(());
return task.finish(Err(InnerJoinError::Cancelled));
}
let future = unsafe { Pin::new_unchecked(&mut task.future) };
let future = unsafe { Pin::new_unchecked(task.future.as_mut().unwrap_unchecked()) };
match panic::catch_unwind(AssertUnwindSafe(|| future.poll(cx))) {
Ok(Poll::Pending) => {
let waker = match task.waker.take() {
Expand All @@ -327,23 +330,13 @@ impl<F: Future> Future for TaskFuture<F> {
}
};
let Ok(waker) = task.cancellation.update(waker) else {
task.joint.wake(Err(InnerJoinError::Cancelled));
task.ready = true;
return Poll::Ready(());
return task.finish(Err(InnerJoinError::Cancelled));
};
task.waker = waker;
Poll::Pending
}
Ok(Poll::Ready(value)) => {
task.joint.wake(Ok(value));
task.ready = true;
Poll::Ready(())
}
Err(err) => {
task.joint.wake(Err(InnerJoinError::Panic(err)));
task.ready = true;
Poll::Ready(())
}
Ok(Poll::Ready(value)) => task.finish(Ok(value)),
Err(err) => task.finish(Err(InnerJoinError::Panic(err))),
}
}
}
Expand Down Expand Up @@ -999,4 +992,51 @@ mod tests {
block_on(Box::into_pin(task.future));
assert_eq!(cancelled.load(Ordering::Relaxed), false);
}

struct CustomFuture {
_shared: Arc<()>,
}

impl Future for CustomFuture {
type Output = ();

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Ready(())
}
}

#[test]
fn future_dropped_before_ready() {
let shared = Arc::new(());
let (mut task, _handle) = Task::new(
Name::default(),
CustomFuture {
_shared: shared.clone(),
},
);
let pinned = unsafe { Pin::new_unchecked(task.future.as_mut()) };
let poll = pinned.poll(&mut Context::from_waker(futures::task::noop_waker_ref()));
assert!(poll.is_ready());
assert_eq!(Arc::strong_count(&shared), 1);
}

#[test]
fn future_dropped_before_joined() {
let shared = Arc::new(());
let (mut task, handle) = Task::new(
Name::default(),
CustomFuture {
_shared: shared.clone(),
},
);
std::thread::spawn(move || {
let pinned = unsafe { Pin::new_unchecked(task.future.as_mut()) };
let _poll = pinned.poll(&mut Context::from_waker(futures::task::noop_waker_ref()));

// Let join handle complete before task drop.
std::thread::sleep(Duration::from_millis(10));
});
block_on(handle).unwrap();
assert_eq!(Arc::strong_count(&shared), 1);
}
}

0 comments on commit d273435

Please sign in to comment.