diff --git a/tokio/src/task/coop/mod.rs b/tokio/src/task/coop/mod.rs index 013d5b1b327..f05f02050fd 100644 --- a/tokio/src/task/coop/mod.rs +++ b/tokio/src/task/coop/mod.rs @@ -305,7 +305,7 @@ cfg_coop! { Poll::Ready(restore) } else { - cx.waker().wake_by_ref(); + defer(cx); Poll::Pending } }).unwrap_or(Poll::Ready(RestoreOnPending(Cell::new(Budget::unconstrained())))) @@ -325,11 +325,19 @@ cfg_coop! { #[inline(always)] fn inc_budget_forced_yield_count() {} } + + fn defer(cx: &mut Context<'_>) { + context::defer(cx.waker()); + } } cfg_not_rt! { #[inline(always)] fn inc_budget_forced_yield_count() {} + + fn defer(cx: &mut Context<'_>) { + cx.waker().wake_by_ref(); + } } impl Budget { diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index c07e3e9ddb9..ab19177ab6b 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -745,7 +745,25 @@ rt_test! { #[cfg_attr(miri, ignore)] // No `socket` in miri. fn yield_defers_until_park() { for _ in 0..10 { - if yield_defers_until_park_inner() { + if yield_defers_until_park_inner(false) { + // test passed + return; + } + + // Wait a bit and run the test again. + std::thread::sleep(std::time::Duration::from_secs(2)); + } + + panic!("yield_defers_until_park is failing consistently"); + } + + /// Same as above, but with cooperative scheduling. + #[test] + #[cfg(not(target_os="wasi"))] + #[cfg_attr(miri, ignore)] // No `socket` in miri. + fn coop_yield_defers_until_park() { + for _ in 0..10 { + if yield_defers_until_park_inner(true) { // test passed return; } @@ -760,10 +778,12 @@ rt_test! { /// Implementation of `yield_defers_until_park` test. Returns `true` if the /// test passed. #[cfg(not(target_os="wasi"))] - fn yield_defers_until_park_inner() -> bool { + fn yield_defers_until_park_inner(use_coop: bool) -> bool { use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; use std::sync::Barrier; + const BUDGET: usize = 128; + let rt = rt(); let flag = Arc::new(AtomicBool::new(false)); @@ -802,7 +822,15 @@ rt_test! { // Yield until connected let mut cnt = 0; while !flag_clone.load(SeqCst){ - tokio::task::yield_now().await; + if use_coop { + // Consume a good chunk of budget, which should + // force at least one yield. + for _ in 0..BUDGET { + tokio::task::consume_budget().await; + } + } else { + tokio::task::yield_now().await; + } cnt += 1; if cnt >= 10 {