diff --git a/rayon-core/Cargo.toml b/rayon-core/Cargo.toml index 260ca7156..4e353c4ef 100644 --- a/rayon-core/Cargo.toml +++ b/rayon-core/Cargo.toml @@ -13,11 +13,16 @@ readme = "README.md" keywords = ["parallel", "thread", "concurrency", "join", "performance"] categories = ["concurrency"] +[features] +default = ["tlv"] +tlv = [] + [dependencies] rand = ">= 0.3, < 0.5" num_cpus = "1.2" libc = "0.2.16" lazy_static = "1" +scoped-tls = "0.1.1" # This is deliberately not the latest version, because we want # to support older rustc than crossbeam-deque 0.3+ does. diff --git a/rayon-core/src/job.rs b/rayon-core/src/job.rs index 096e41294..5589f1caa 100644 --- a/rayon-core/src/job.rs +++ b/rayon-core/src/job.rs @@ -3,6 +3,8 @@ use std::any::Any; use std::cell::UnsafeCell; use std::mem; use unwind; +#[cfg(feature = "tlv")] +use tlv; pub enum JobResult { None, @@ -73,6 +75,8 @@ pub struct StackJob pub latch: L, func: UnsafeCell>, result: UnsafeCell>, + #[cfg(feature = "tlv")] + tlv: usize, } impl StackJob @@ -85,6 +89,8 @@ impl StackJob latch: latch, func: UnsafeCell::new(Some(func)), result: UnsafeCell::new(JobResult::None), + #[cfg(feature = "tlv")] + tlv: tlv::get(), } } @@ -108,6 +114,8 @@ impl Job for StackJob { unsafe fn execute(this: *const Self) { let this = &*this; + #[cfg(feature = "tlv")] + tlv::set(this.tlv); let abort = unwind::AbortIfPanic; let func = (*this.func.get()).take().unwrap(); (*this.result.get()) = match unwind::halt_unwinding(|| func(true)) { @@ -129,13 +137,19 @@ pub struct HeapJob where BODY: FnOnce() + Send { job: UnsafeCell>, + #[cfg(feature = "tlv")] + tlv: usize, } impl HeapJob where BODY: FnOnce() + Send { pub fn new(func: BODY) -> Self { - HeapJob { job: UnsafeCell::new(Some(func)) } + HeapJob { + job: UnsafeCell::new(Some(func)), + #[cfg(feature = "tlv")] + tlv: tlv::get(), + } } /// Creates a `JobRef` from this job -- note that this hides all @@ -152,6 +166,8 @@ impl Job for HeapJob { unsafe fn execute(this: *const Self) { let this: Box = mem::transmute(this); + #[cfg(feature = "tlv")] + tlv::set(this.tlv); let job = (*this.job.get()).take().unwrap(); job(); } diff --git a/rayon-core/src/join/mod.rs b/rayon-core/src/join/mod.rs index 8233d3bad..21c0b6ce4 100644 --- a/rayon-core/src/join/mod.rs +++ b/rayon-core/src/join/mod.rs @@ -1,10 +1,9 @@ use latch::{LatchProbe, SpinLatch}; use log::Event::*; use job::StackJob; -use registry::{self, WorkerThread}; -use std::any::Any; use unwind; - +use registry; +use PoisonedJob; use FnContext; #[cfg(test)] @@ -128,7 +127,17 @@ pub fn join_context(oper_a: A, oper_b: B) -> (RA, RB) let status_a = unwind::halt_unwinding(move || oper_a(FnContext::new(injected))); let result_a = match status_a { Ok(v) => v, - Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err), + Err(err) => { + // If job A panics, we still cannot return until we are sure that job + // B is complete. This is because it may contain references into the + // enclosing stack frame(s). + worker_thread.wait_until(&job_b.latch); + if err.is::() { + // Job A was poisoned, so unwind the panic of Job B if it exists + job_b.into_result(); + } + unwind::resume_unwinding(err) + }, }; // Now that task A has finished, try to pop job B from the @@ -162,16 +171,3 @@ pub fn join_context(oper_a: A, oper_b: B) -> (RA, RB) return (result_a, job_b.into_result()); }) } - -/// If job A panics, we still cannot return until we are sure that job -/// B is complete. This is because it may contain references into the -/// enclosing stack frame(s). -#[cold] // cold path -unsafe fn join_recover_from_panic(worker_thread: &WorkerThread, - job_b_latch: &SpinLatch, - err: Box) - -> ! -{ - worker_thread.wait_until(job_b_latch); - unwind::resume_unwinding(err) -} diff --git a/rayon-core/src/lib.rs b/rayon-core/src/lib.rs index 1e85e69f0..9fa1b32d7 100644 --- a/rayon-core/src/lib.rs +++ b/rayon-core/src/lib.rs @@ -35,6 +35,8 @@ use std::fmt; extern crate crossbeam_deque; #[macro_use] extern crate lazy_static; +#[macro_use] +extern crate scoped_tls; extern crate libc; extern crate num_cpus; extern crate rand; @@ -50,10 +52,14 @@ mod scope; mod sleep; mod spawn; mod test; +mod thread_local; mod thread_pool; mod unwind; mod util; +#[cfg(feature = "tlv")] +pub mod tlv; + #[cfg(rayon_unstable)] pub mod internal; pub use thread_pool::ThreadPool; @@ -62,6 +68,7 @@ pub use thread_pool::current_thread_has_pending_tasks; pub use join::{join, join_context}; pub use scope::{scope, Scope}; pub use spawn::spawn; +pub use thread_local::ThreadLocal; /// Returns the number of threads in the current registry. If this /// code is executing within a Rayon thread-pool, then this will be @@ -85,6 +92,11 @@ pub fn current_num_threads() -> usize { ::registry::Registry::current_num_threads() } +/// A value which can be thrown which will give priority to other panics +/// for `join` and `scope` +#[derive(Debug)] +pub struct PoisonedJob; + /// Error when initializing a thread pool. #[derive(Debug)] pub struct ThreadPoolBuildError { @@ -138,6 +150,9 @@ pub struct ThreadPoolBuilder { /// Closure invoked on worker thread exit. exit_handler: Option>, + /// Closure invoked on worker thread start. + main_handler: Option>, + /// If false, worker threads will execute spawned jobs in a /// "depth-first" fashion. If true, they will do a "breadth-first" /// fashion. Depth-first is the default. @@ -167,6 +182,12 @@ type StartHandler = Fn(usize) + Send + Sync; /// Note that this same closure may be invoked multiple times in parallel. type ExitHandler = Fn(usize) + Send + Sync; +/// The type for a closure that gets invoked with a +/// function which runs rayon tasks. +/// The closure is passed the index of the thread on which it is invoked. +/// Note that this same closure may be invoked multiple times in parallel. +type MainHandler = Fn(usize, &mut FnMut()) + Send + Sync; + impl ThreadPoolBuilder { /// Creates and returns a valid rayon thread pool builder, but does not initialize it. pub fn new() -> ThreadPoolBuilder { @@ -366,6 +387,23 @@ impl ThreadPoolBuilder { self.exit_handler = Some(Box::new(exit_handler)); self } + + /// Takes the current thread main callback, leaving `None`. + fn take_main_handler(&mut self) -> Option> { + self.main_handler.take() + } + + /// Set a callback to be invoked on thread main. + /// + /// The closure is passed the index of the thread on which it is invoked. + /// Note that this same closure may be invoked multiple times in parallel. + /// If this closure panics, the panic will be passed to the panic handler. + pub fn main_handler(mut self, main_handler: H) -> ThreadPoolBuilder + where H: Fn(usize, &mut FnMut()) + Send + Sync + 'static + { + self.main_handler = Some(Box::new(main_handler)); + self + } } #[allow(deprecated)] @@ -471,7 +509,7 @@ impl fmt::Debug for ThreadPoolBuilder { let ThreadPoolBuilder { ref num_threads, ref get_thread_name, ref panic_handler, ref stack_size, ref start_handler, ref exit_handler, - ref breadth_first } = *self; + ref main_handler, ref breadth_first } = *self; // Just print `Some()` or `None` to the debug // output. @@ -485,6 +523,7 @@ impl fmt::Debug for ThreadPoolBuilder { let panic_handler = panic_handler.as_ref().map(|_| ClosurePlaceholder); let start_handler = start_handler.as_ref().map(|_| ClosurePlaceholder); let exit_handler = exit_handler.as_ref().map(|_| ClosurePlaceholder); + let main_handler = main_handler.as_ref().map(|_| ClosurePlaceholder); f.debug_struct("ThreadPoolBuilder") .field("num_threads", num_threads) @@ -493,6 +532,7 @@ impl fmt::Debug for ThreadPoolBuilder { .field("stack_size", &stack_size) .field("start_handler", &start_handler) .field("exit_handler", &exit_handler) + .field("main_handler", &main_handler) .field("breadth_first", &breadth_first) .finish() } diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index 18748b344..cfefeadff 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -1,4 +1,5 @@ -use ::{ExitHandler, PanicHandler, StartHandler, ThreadPoolBuilder, ThreadPoolBuildError, ErrorKind}; +use ::{ExitHandler, PanicHandler, StartHandler, MainHandler, + ThreadPoolBuilder, ThreadPoolBuildError, ErrorKind}; use crossbeam_deque::{Deque, Steal, Stealer}; use job::{JobRef, StackJob}; #[cfg(rayon_unstable)] @@ -27,6 +28,7 @@ pub struct Registry { panic_handler: Option>, start_handler: Option>, exit_handler: Option>, + main_handler: Option>, // When this latch reaches 0, it means that all work on this // registry must be complete. This is ensured in the following ways: @@ -116,6 +118,7 @@ impl Registry { terminate_latch: CountLatch::new(), panic_handler: builder.take_panic_handler(), start_handler: builder.take_start_handler(), + main_handler: builder.take_main_handler(), exit_handler: builder.take_exit_handler(), }); @@ -212,8 +215,7 @@ impl Registry { /// Waits for the worker threads to stop. This is used for testing /// -- so we can check that termination actually works. - #[cfg(test)] - pub fn wait_until_stopped(&self) { + pub(crate) fn wait_until_stopped(&self) { for info in &self.thread_infos { info.stopped.wait(); } @@ -454,7 +456,7 @@ pub struct WorkerThread { /// the "worker" half of our local deque worker: Deque, - index: usize, + pub(crate) index: usize, /// are these workers configured to steal breadth-first or not? breadth_first: bool, @@ -462,7 +464,7 @@ pub struct WorkerThread { /// A weak random number generator. rng: UnsafeCell, - registry: Arc, + pub(crate) registry: Arc, } // This is a bit sketchy, but basically: the WorkerThread is @@ -671,7 +673,21 @@ unsafe fn main_loop(worker: Deque, } } - worker_thread.wait_until(®istry.terminate_latch); + let mut work = || { + worker_thread.wait_until(®istry.terminate_latch); + }; + + if let Some(ref handler) = registry.main_handler { + match unwind::halt_unwinding(|| handler(index, &mut work)) { + Ok(()) => { + } + Err(err) => { + registry.handle_panic(err); + } + } + } else { + work(); + } // Should not be any work left in our queue. debug_assert!(worker_thread.take_local_job().is_none()); diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index 5f349b97e..a70ed1a5d 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use std::sync::atomic::{AtomicPtr, Ordering}; use registry::{in_worker, WorkerThread, Registry}; use unwind; +use PoisonedJob; #[cfg(test)] mod test; @@ -37,6 +38,10 @@ pub struct Scope<'scope> { /// propagated to the one who created the scope panic: AtomicPtr>, + /// if some job panicked with PoisonedJob, the error is stored here; it will be + /// propagated to the one who created the scope unless there is a proper panic + poisoned_panic: AtomicPtr>, + /// latch to set when the counter drops to zero (and hence this scope is complete) job_completed_latch: CountLatch, @@ -265,6 +270,7 @@ pub fn scope<'scope, OP, R>(op: OP) -> R owner_thread_index: owner_thread.index(), registry: owner_thread.registry().clone(), panic: AtomicPtr::new(ptr::null_mut()), + poisoned_panic: AtomicPtr::new(ptr::null_mut()), job_completed_latch: CountLatch::new(), marker: PhantomData, }; @@ -371,14 +377,18 @@ impl<'scope> Scope<'scope> { // capture the first error we see, free the rest let nil = ptr::null_mut(); let mut err = Box::new(err); // box up the fat ptr - if self.panic.compare_exchange(nil, &mut *err, Ordering::Release, Ordering::Relaxed).is_ok() { + let field = if err.is::() { + &self.poisoned_panic + } else { + &self.panic + }; + if field.compare_exchange(nil, &mut *err, Ordering::Release, Ordering::Relaxed).is_ok() { log!(JobPanickedErrorStored { owner_thread: self.owner_thread_index }); mem::forget(err); // ownership now transferred into self.panic } else { log!(JobPanickedErrorNotStored { owner_thread: self.owner_thread_index }); } - self.job_completed_latch.set(); } @@ -394,7 +404,10 @@ impl<'scope> Scope<'scope> { // propagate panic, if any occurred; at this point, all // outstanding jobs have completed, so we can use a relaxed // ordering: - let panic = self.panic.swap(ptr::null_mut(), Ordering::Relaxed); + let mut panic = self.panic.swap(ptr::null_mut(), Ordering::Relaxed); + if panic.is_null() { + panic = self.poisoned_panic.swap(ptr::null_mut(), Ordering::Relaxed); + } if !panic.is_null() { log!(ScopeCompletePanicked { owner_thread: owner_thread.index() }); let value: Box> = mem::transmute(panic); diff --git a/rayon-core/src/thread_local.rs b/rayon-core/src/thread_local.rs new file mode 100644 index 000000000..9946cfc84 --- /dev/null +++ b/rayon-core/src/thread_local.rs @@ -0,0 +1,74 @@ +use registry::{Registry, WorkerThread}; +use std::fmt; +use std::ops::Deref; +use std::sync::Arc; + +#[repr(align(64))] +#[derive(Debug)] +struct CacheAligned(T); + +/// Holds thread-locals values for each thread in a thread pool. +/// You can only access the thread local value through the Deref impl +/// on the thread pool it was constructed on. It will panic otherwise +pub struct ThreadLocal { + locals: Vec>, + registry: Arc, +} + +unsafe impl Send for ThreadLocal {} +unsafe impl Sync for ThreadLocal {} + +impl ThreadLocal { + /// Creates a new thread local where the `initial` closure computes the + /// value this thread local should take for each thread in the thread pool. + #[inline] + pub fn new T>(mut initial: F) -> ThreadLocal { + let registry = Registry::current(); + ThreadLocal { + locals: (0..registry.num_threads()) + .map(|i| CacheAligned(initial(i))) + .collect(), + registry, + } + } + + /// Returns the thread-local value for each thread + #[inline] + pub fn into_inner(self) -> Vec { + self.locals.into_iter().map(|c| c.0).collect() + } + + fn current(&self) -> &T { + unsafe { + let worker_thread = WorkerThread::current(); + if worker_thread.is_null() + || &*(*worker_thread).registry as *const _ != &*self.registry as *const _ + { + panic!("ThreadLocal can only be used on the thread pool it was created on") + } + &self.locals[(*worker_thread).index].0 + } + } +} + +impl ThreadLocal> { + /// Joins the elements of all the thread locals into one Vec + pub fn join(self) -> Vec { + self.into_inner().into_iter().flat_map(|v| v).collect() + } +} + +impl fmt::Debug for ThreadLocal { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self.locals, f) + } +} + +impl Deref for ThreadLocal { + type Target = T; + + #[inline(always)] + fn deref(&self) -> &T { + self.current() + } +} diff --git a/rayon-core/src/thread_pool/mod.rs b/rayon-core/src/thread_pool/mod.rs index 1f0088e75..43b21993f 100644 --- a/rayon-core/src/thread_pool/mod.rs +++ b/rayon-core/src/thread_pool/mod.rs @@ -85,6 +85,40 @@ impl ThreadPool { &DEFAULT_THREAD_POOL } + /// Creates a scoped thread pool + pub fn scoped_pool(builder: ThreadPoolBuilder, + main_handler: H, + with_pool: F) -> Result + where F: FnOnce(&ThreadPool) -> R, + H: Fn(&mut FnMut()) + Send + Sync + { + struct Handler(*const ()); + unsafe impl Send for Handler {} + unsafe impl Sync for Handler {} + + let handler = Handler(&main_handler as *const _ as *const ()); + + let builder = builder.main_handler(move |_, worker| { + let handler = unsafe { &*(handler.0 as *const H) }; + handler(worker); + }); + + let pool = builder.build()?; + + struct JoinRegistry(Arc); + + impl Drop for JoinRegistry { + fn drop(&mut self) { + self.0.terminate(); + self.0.wait_until_stopped(); + } + } + + let _join_registry = JoinRegistry(pool.registry.clone()); + + Ok(with_pool(&pool)) + } + /// Executes `op` within the threadpool. Any attempts to use /// `join`, `scope`, or parallel iterators will then operate /// within that threadpool. diff --git a/rayon-core/src/tlv.rs b/rayon-core/src/tlv.rs new file mode 100644 index 000000000..f035d12d9 --- /dev/null +++ b/rayon-core/src/tlv.rs @@ -0,0 +1,30 @@ +//! Allows access to the Rayon's thread local value +//! which is preserved when moving jobs across threads + +use std::cell::Cell; + +thread_local!(pub(crate) static TLV: Cell = Cell::new(0)); + +/// Sets the current thread-local value to `value` inside the closure. +/// The old value is restored when the closure ends +pub fn with R, R>(value: usize, f: F) -> R { + struct Reset(usize); + impl Drop for Reset { + fn drop(&mut self) { + TLV.with(|tlv| tlv.set(self.0)); + } + } + let _reset = Reset(get()); + TLV.with(|tlv| tlv.set(value)); + f() +} + +/// Sets the current thread-local value +pub fn set(value: usize) { + TLV.with(|tlv| tlv.set(value)); +} + +/// Returns the current thread-local value +pub fn get() -> usize { + TLV.with(|tlv| tlv.get()) +}