diff --git a/tokio-threadpool/Cargo.toml b/tokio-threadpool/Cargo.toml
index f891f7e3386..324071e8290 100644
--- a/tokio-threadpool/Cargo.toml
+++ b/tokio-threadpool/Cargo.toml
@@ -20,11 +20,13 @@ categories = ["concurrency", "asynchronous"]
 [dependencies]
 tokio-executor = { version = "0.1.2", path = "../tokio-executor" }
 futures = "0.1.19"
+crossbeam = "0.6.0"
 crossbeam-channel = "0.3.3"
 crossbeam-deque = "0.6.1"
 crossbeam-utils = "0.6.2"
 num_cpus = "1.2"
 rand = "0.6"
+slab = "0.4.1"
 log = "0.4"
 
 [dev-dependencies]
diff --git a/tokio-threadpool/src/lib.rs b/tokio-threadpool/src/lib.rs
index f1a065ba16d..526f88ab12a 100644
--- a/tokio-threadpool/src/lib.rs
+++ b/tokio-threadpool/src/lib.rs
@@ -79,6 +79,7 @@
 
 extern crate tokio_executor;
 
+extern crate crossbeam;
 extern crate crossbeam_channel;
 extern crate crossbeam_deque as deque;
 extern crate crossbeam_utils;
@@ -86,6 +87,7 @@ extern crate crossbeam_utils;
 extern crate futures;
 extern crate num_cpus;
 extern crate rand;
+extern crate slab;
 
 #[macro_use]
 extern crate log;
diff --git a/tokio-threadpool/src/pool/mod.rs b/tokio-threadpool/src/pool/mod.rs
index bb9fd6a2598..3d7a9377394 100644
--- a/tokio-threadpool/src/pool/mod.rs
+++ b/tokio-threadpool/src/pool/mod.rs
@@ -45,12 +45,6 @@ pub(crate) struct Pool {
     // Stack tracking sleeping workers.
     sleep_stack: CachePadded<worker::Stack>,
 
-    // Number of workers that haven't reached the final state of shutdown
-    //
-    // This is only used to know when to single `shutdown_task` once the
-    // shutdown process has completed.
-    pub num_workers: AtomicUsize,
-
     // Worker state
     //
     // A worker is a thread that is processing the work queue and polling
@@ -122,7 +116,6 @@ impl Pool {
         let ret = Pool {
             state: CachePadded::new(AtomicUsize::new(State::new().into())),
             sleep_stack: CachePadded::new(worker::Stack::new()),
-            num_workers: AtomicUsize::new(0),
             workers,
             queue,
             trigger,
@@ -313,7 +306,6 @@ impl Pool {
         }
 
         let trigger = match self.trigger.upgrade() {
-            // The pool is shutting down.
             None => {
                 // The pool is shutting down.
                 return;
diff --git a/tokio-threadpool/src/shutdown.rs b/tokio-threadpool/src/shutdown.rs
index 1cc19b54671..880fc822016 100644
--- a/tokio-threadpool/src/shutdown.rs
+++ b/tokio-threadpool/src/shutdown.rs
@@ -86,6 +86,11 @@ impl Drop for ShutdownTrigger {
         // Drain the global task queue.
         while self.queue.pop().is_some() {}
 
+        // Drop the remaining incomplete tasks and the parkers associated with workers.
+        for worker in self.workers.iter() {
+            worker.shutdown();
+        }
+
         // Notify the task interested in shutdown.
         let mut inner = self.inner.lock().unwrap();
         inner.completed = true;
diff --git a/tokio-threadpool/src/task/mod.rs b/tokio-threadpool/src/task/mod.rs
index ef873af571e..90592f203da 100644
--- a/tokio-threadpool/src/task/mod.rs
+++ b/tokio-threadpool/src/task/mod.rs
@@ -15,10 +15,10 @@ use futures::{self, Future, Async};
 use futures::executor::{self, Spawn};
 
 use std::{fmt, panic, ptr};
-use std::cell::{UnsafeCell};
+use std::cell::{Cell, UnsafeCell};
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, AtomicPtr};
-use std::sync::atomic::Ordering::{AcqRel, Release, Relaxed};
+use std::sync::atomic::Ordering::{AcqRel, Acquire, Release, Relaxed};
 
 /// Harness around a future.
 ///
@@ -34,6 +34,21 @@ pub(crate) struct Task {
     /// Next pointer in the queue of tasks pending blocking capacity.
     next_blocking: AtomicPtr<Task>,
 
+    /// ID of the worker that polled this task first.
+    ///
+    /// This field can be a `Cell` because it's only accessed by the worker thread that is
+    /// executing the task.
+    ///
+    /// The worker ID is represented by a `u32` rather than `usize` in order to save some space
+    /// on 64-bit platforms.
+    pub reg_worker: Cell<Option<u32>>,
+
+    /// The key associated with this task in the `Slab` it was registered in.
+    ///
+    /// This field can be a `Cell` because it's only accessed by the worker thread that has
+    /// registered the task.
+    pub reg_index: Cell<usize>,
+
     /// Store the future at the head of the struct
     ///
     /// The future is dropped immediately when it transitions to Complete
@@ -61,6 +76,8 @@ impl Task {
             state: AtomicUsize::new(State::new().into()),
             blocking: AtomicUsize::new(BlockingState::new().into()),
             next_blocking: AtomicPtr::new(ptr::null_mut()),
+            reg_worker: Cell::new(None),
+            reg_index: Cell::new(0),
             future: UnsafeCell::new(Some(task_fut)),
         }
     }
@@ -75,6 +92,8 @@ impl Task {
             state: AtomicUsize::new(State::stub().into()),
             blocking: AtomicUsize::new(BlockingState::new().into()),
             next_blocking: AtomicPtr::new(ptr::null_mut()),
+            reg_worker: Cell::new(None),
+            reg_index: Cell::new(0),
             future: UnsafeCell::new(Some(task_fut)),
         }
     }
@@ -166,6 +185,41 @@ impl Task {
         }
     }
 
+    /// Aborts this task.
+    ///
+    /// This is called when the threadpool shuts down and the task has already been polled but not
+    /// completed.
+    pub fn abort(&self) {
+        use self::State::*;
+
+        let mut state = self.state.load(Acquire).into();
+
+        loop {
+            match state {
+                Idle | Scheduled => {}
+                Running | Notified | Complete | Aborted => {
+                    // It is assumed that no worker threads are running, so the task must be either
+                    // in the idle or scheduled state.
+                    panic!("unexpected state while aborting task: {:?}", state);
+                }
+            }
+
+            let actual = self.state.compare_and_swap(
+                state.into(),
+                Aborted.into(),
+                AcqRel).into();
+
+            if actual == state {
+                // The future has been aborted. Drop it immediately to free resources and run drop
+                // handlers.
+                self.drop_future();
+                break;
+            }
+
+            state = actual;
+        }
+    }
+
     /// Notify the task
     pub fn notify(me: Arc<Task>, pool: &Arc<Pool>) {
         if me.schedule() {
@@ -206,7 +260,7 @@ impl Task {
                     _ => return false,
                 }
             }
-            Complete | Notified | Scheduled => return false,
+            Complete | Aborted | Notified | Scheduled => return false,
         }
     }
 }
diff --git a/tokio-threadpool/src/task/state.rs b/tokio-threadpool/src/task/state.rs
index 9023eec5fbb..e01501c214e 100644
--- a/tokio-threadpool/src/task/state.rs
+++ b/tokio-threadpool/src/task/state.rs
@@ -15,6 +15,9 @@ pub(crate) enum State {
 
     /// Task is complete
     Complete = 4,
+
+    /// Task was aborted because the thread pool has been shut down
+    Aborted = 5,
 }
 
 // ===== impl State =====
@@ -39,7 +42,7 @@ impl From<usize> for State {
 
         debug_assert!(
             src >= Idle as usize &&
-            src <= Complete as usize, "actual={}", src);
+            src <= Aborted as usize, "actual={}", src);
 
         unsafe { ::std::mem::transmute(src) }
     }
diff --git a/tokio-threadpool/src/worker/entry.rs b/tokio-threadpool/src/worker/entry.rs
index d1013381ee6..07bee5bd298 100644
--- a/tokio-threadpool/src/worker/entry.rs
+++ b/tokio-threadpool/src/worker/entry.rs
@@ -5,11 +5,14 @@ use worker::state::{State, PUSHED_MASK};
 use std::cell::UnsafeCell;
 use std::fmt;
 use std::sync::Arc;
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::atomic::Ordering::{AcqRel, Relaxed};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::atomic::Ordering::{Acquire, AcqRel, Relaxed, Release};
+use std::time::Duration;
 
+use crossbeam::queue::SegQueue;
 use crossbeam_utils::CachePadded;
 use deque;
+use slab::Slab;
 
 // TODO: None of the fields should be public
 //
@@ -32,10 +35,20 @@ pub(crate) struct WorkerEntry {
     stealer: deque::Stealer<Arc<Task>>,
 
     // Thread parker
-    pub park: UnsafeCell<BoxPark>,
+    park: UnsafeCell<Option<BoxPark>>,
 
     // Thread unparker
-    pub unpark: BoxUnpark,
+    unpark: UnsafeCell<Option<BoxUnpark>>,
+
+    // Tasks that have been first polled by this worker, but not completed yet.
+    running_tasks: UnsafeCell<Slab<Arc<Task>>>,
+
+    // Tasks that have been first polled by this worker, but completed by another worker.
+    remotely_completed_tasks: SegQueue<Arc<Task>>,
+
+    // Set to `true` when `remotely_completed_tasks` has tasks that need to be removed from
+    // `running_tasks`.
+    needs_drain: AtomicBool,
 }
 
 impl WorkerEntry {
@@ -47,8 +60,11 @@ impl WorkerEntry {
             next_sleeper: UnsafeCell::new(0),
             worker: w,
             stealer: s,
-            park: UnsafeCell::new(park),
-            unpark,
+            park: UnsafeCell::new(Some(park)),
+            unpark: UnsafeCell::new(Some(unpark)),
+            running_tasks: UnsafeCell::new(Slab::new()),
+            remotely_completed_tasks: SegQueue::new(),
+            needs_drain: AtomicBool::new(false),
         }
     }
@@ -100,7 +116,7 @@ impl WorkerEntry {
             Sleeping => {
                 // The worker is currently sleeping, the condition variable must
                 // be signaled
-                self.wakeup();
+                self.unpark();
                 true
             }
             Shutdown => false,
@@ -163,7 +179,7 @@ impl WorkerEntry {
         }
 
         // Wakeup the worker
-        self.wakeup();
+        self.unpark();
     }
 
     /// Pop a task
@@ -202,14 +218,94 @@ impl WorkerEntry {
         }
     }
 
+    /// Parks the worker thread.
+    pub fn park(&self) {
+        if let Some(park) = unsafe { (*self.park.get()).as_mut() } {
+            park.park().unwrap();
+        }
+    }
+
+    /// Parks the worker thread for at most `duration`.
+    pub fn park_timeout(&self, duration: Duration) {
+        if let Some(park) = unsafe { (*self.park.get()).as_mut() } {
+            park.park_timeout(duration).unwrap();
+        }
+    }
+
+    /// Unparks the worker thread.
     #[inline]
-    pub fn push_internal(&self, task: Arc<Task>) {
-        self.worker.push(task);
+    pub fn unpark(&self) {
+        if let Some(park) = unsafe { (*self.unpark.get()).as_ref() } {
+            park.unpark();
+        }
+    }
+
+    /// Registers a task in this worker.
+    ///
+    /// Called when the task is being polled for the first time.
+    #[inline]
+    pub fn register_task(&self, task: &Arc<Task>) {
+        let running_tasks = unsafe { &mut *self.running_tasks.get() };
+
+        let key = running_tasks.insert(task.clone());
+        task.reg_index.set(key);
     }
 
+    /// Unregisters a task from this worker.
+    ///
+    /// Called when the task is completed and was previously registered in this worker.
     #[inline]
-    pub fn wakeup(&self) {
-        self.unpark.unpark();
+    pub fn unregister_task(&self, task: Arc<Task>) {
+        let running_tasks = unsafe { &mut *self.running_tasks.get() };
+        running_tasks.remove(task.reg_index.get());
+        self.drain_remotely_completed_tasks();
+    }
+
+    /// Unregisters a task from this worker.
+    ///
+    /// Called when the task is completed by another worker and was previously registered in this
+    /// worker.
+    #[inline]
+    pub fn remotely_complete_task(&self, task: Arc<Task>) {
+        self.remotely_completed_tasks.push(task);
+        self.needs_drain.store(true, Release);
+    }
+
+    /// Drops the remaining incomplete tasks and the parker associated with this worker.
+    ///
+    /// This function is called by the shutdown trigger.
+    pub fn shutdown(&self) {
+        self.drain_remotely_completed_tasks();
+
+        // Abort all incomplete tasks.
+        let running_tasks = unsafe { &mut *self.running_tasks.get() };
+        for (_, task) in running_tasks.iter() {
+            task.abort();
+        }
+        running_tasks.clear();
+
+        // Drop the parker.
+        unsafe {
+            *self.park.get() = None;
+            *self.unpark.get() = None;
+        }
+    }
+
+    /// Drains the `remotely_completed_tasks` queue and removes tasks from `running_tasks`.
+    #[inline]
+    fn drain_remotely_completed_tasks(&self) {
+        if self.needs_drain.compare_and_swap(true, false, Acquire) {
+            let running_tasks = unsafe { &mut *self.running_tasks.get() };
+
+            while let Some(task) = self.remotely_completed_tasks.try_pop() {
+                running_tasks.remove(task.reg_index.get());
+            }
+        }
+    }
+
+    #[inline]
+    pub fn push_internal(&self, task: Arc<Task>) {
+        self.worker.push(task);
     }
 
     #[inline]
diff --git a/tokio-threadpool/src/worker/mod.rs b/tokio-threadpool/src/worker/mod.rs
index a82fe9f52eb..e270c10d935 100644
--- a/tokio-threadpool/src/worker/mod.rs
+++ b/tokio-threadpool/src/worker/mod.rs
@@ -451,6 +451,13 @@ impl Worker {
     fn run_task(&self, task: Arc<Task>, notify: &Arc<Notifier>) {
         use task::Run::*;
 
+        // If this is the first time this task is being polled, register it so that we can keep
+        // track of tasks that are in progress.
+        if task.reg_worker.get().is_none() {
+            task.reg_worker.set(Some(self.id.0 as u32));
+            self.entry().register_task(&task);
+        }
+
         let run = self.run_task2(&task, notify);
 
         // TODO: Try to claim back the worker state in case the backup thread
@@ -497,6 +504,16 @@ impl Worker {
                             }
                         }
 
+                        // Find which worker polled this task first.
+                        let worker = task.reg_worker.get().unwrap() as usize;
+
+                        // Unregister the task from the worker it was registered in.
+                        if !self.is_blocking.get() && worker == self.id.0 {
+                            self.entry().unregister_task(task);
+                        } else {
+                            self.pool.workers[worker].remotely_complete_task(task);
+                        }
+
                         // The worker's run loop will detect the shutdown state
                         // next iteration.
                         return;
@@ -672,11 +689,7 @@ impl Worker {
             }
         }
 
-        unsafe {
-            (*self.entry().park.get())
-                .park()
-                .unwrap();
-        }
+        self.entry().park();
 
         trace!(" -> wakeup; idx={}", self.id.0);
     }
@@ -690,11 +703,7 @@ impl Worker {
     fn sleep_light(&self) {
         const STEAL_COUNT: usize = 32;
 
-        unsafe {
-            (*self.entry().park.get())
-                .park_timeout(Duration::from_millis(0))
-                .unwrap();
-        }
+        self.entry().park_timeout(Duration::from_millis(0));
 
         for _ in 0..STEAL_COUNT {
             if let Some(task) = self.pool.queue.pop() {
diff --git a/tokio-threadpool/tests/threadpool.rs b/tokio-threadpool/tests/threadpool.rs
index 0df1d1c2fe4..c9c3a76c1fc 100644
--- a/tokio-threadpool/tests/threadpool.rs
+++ b/tokio-threadpool/tests/threadpool.rs
@@ -3,7 +3,10 @@ extern crate tokio_executor;
 extern crate futures;
 extern crate env_logger;
 
+use tokio_executor::park::{Park, Unpark};
 use tokio_threadpool::*;
+use tokio_threadpool::park::{DefaultPark, DefaultUnpark};
+
 use futures::{Poll, Sink, Stream, Async, Future};
 use futures::future::lazy;
@@ -420,3 +423,113 @@ fn multi_threadpool() {
 
     done_rx.recv().unwrap();
 }
+
+#[test]
+fn eagerly_drops_futures() {
+    use futures::future::{Future, lazy, empty};
+    use futures::task;
+    use std::sync::mpsc;
+
+    struct NotifyOnDrop(mpsc::Sender<()>);
+
+    impl Drop for NotifyOnDrop {
+        fn drop(&mut self) {
+            self.0.send(()).unwrap();
+        }
+    }
+
+    struct MyPark {
+        inner: DefaultPark,
+        #[allow(dead_code)]
+        park_tx: mpsc::SyncSender<()>,
+        unpark_tx: mpsc::SyncSender<()>,
+    }
+
+    impl Park for MyPark {
+        type Unpark = MyUnpark;
+        type Error = <DefaultPark as Park>::Error;
+
+        fn unpark(&self) -> Self::Unpark {
+            MyUnpark {
+                inner: self.inner.unpark(),
+                unpark_tx: self.unpark_tx.clone(),
+            }
+        }
+
+        fn park(&mut self) -> Result<(), Self::Error> {
+            self.inner.park()
+        }
+
+        fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
+            self.inner.park_timeout(duration)
+        }
+    }
+
+    struct MyUnpark {
+        inner: DefaultUnpark,
+        #[allow(dead_code)]
+        unpark_tx: mpsc::SyncSender<()>,
+    }
+
+    impl Unpark for MyUnpark {
+        fn unpark(&self) {
+            self.inner.unpark()
+        }
+    }
+
+    let (task_tx, task_rx) = mpsc::channel();
+    let (drop_tx, drop_rx) = mpsc::channel();
+    let (park_tx, park_rx) = mpsc::sync_channel(0);
+    let (unpark_tx, unpark_rx) = mpsc::sync_channel(0);
+
+    // Get the signal that the handler dropped.
+    let notify_on_drop = NotifyOnDrop(drop_tx);
+
+    let pool = tokio_threadpool::Builder::new()
+        .custom_park(move |_| {
+            MyPark {
+                inner: DefaultPark::new(),
+                park_tx: park_tx.clone(),
+                unpark_tx: unpark_tx.clone(),
+            }
+        })
+        .build();
+
+    pool.spawn(lazy(move || {
+        // Get a handle to the current task.
+        let task = task::current();
+
+        // Send it to the main thread to hold on to.
+        task_tx.send(task).unwrap();
+
+        // This future will never resolve; it is only used to hold on to the
+        // `notify_on_drop` handle.
+        empty::<(), ()>().then(move |_| {
+            // This code path should never be reached.
+            if true { panic!() }
+
+            // Explicitly drop `notify_on_drop` here; this is mostly to ensure
+            // that the `notify_on_drop` handle gets moved into the task. It
+            // will actually get dropped when the runtime is dropped.
+            drop(notify_on_drop);
+
+            Ok(())
+        })
+    }));
+
+    // Wait until we get the task handle.
+    let task = task_rx.recv().unwrap();
+
+    // Drop the pool; this should result in futures being forcefully dropped.
+    drop(pool);
+
+    // Make sure `MyPark` and `MyUnpark` were dropped during shutdown.
+    assert_eq!(park_rx.try_recv(), Err(mpsc::TryRecvError::Disconnected));
+    assert_eq!(unpark_rx.try_recv(), Err(mpsc::TryRecvError::Disconnected));
+
+    // If the future is forcefully dropped, then we will get a signal here.
+    drop_rx.recv().unwrap();
+
+    // Ensure `task` lives until after the test completes.
+    drop(task);
+}