diff --git a/rayon-core/src/lib.rs b/rayon-core/src/lib.rs index e6e4ad7c8..47d977d29 100644 --- a/rayon-core/src/lib.rs +++ b/rayon-core/src/lib.rs @@ -50,6 +50,7 @@ mod scope; mod sleep; mod spawn; mod test; +mod thread_local; mod thread_pool; mod unwind; mod util; @@ -65,6 +66,7 @@ pub use thread_pool::current_thread_has_pending_tasks; pub use join::{join, join_context}; pub use scope::{scope, Scope}; pub use spawn::spawn; +pub use thread_local::ThreadLocal; /// Returns the number of threads in the current registry. If this /// code is executing within a Rayon thread-pool, then this will be diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index dcdb4e1f4..cfefeadff 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -456,7 +456,7 @@ pub struct WorkerThread { /// the "worker" half of our local deque worker: Deque, - index: usize, + pub(crate) index: usize, /// are these workers configured to steal breadth-first or not? breadth_first: bool, @@ -464,7 +464,7 @@ pub struct WorkerThread { /// A weak random number generator. rng: UnsafeCell, - registry: Arc, + pub(crate) registry: Arc, } // This is a bit sketchy, but basically: the WorkerThread is diff --git a/rayon-core/src/thread_local.rs b/rayon-core/src/thread_local.rs new file mode 100644 index 000000000..9946cfc84 --- /dev/null +++ b/rayon-core/src/thread_local.rs @@ -0,0 +1,74 @@ +use registry::{Registry, WorkerThread}; +use std::fmt; +use std::ops::Deref; +use std::sync::Arc; + +#[repr(align(64))] +#[derive(Debug)] +struct CacheAligned(T); + +/// Holds thread-locals values for each thread in a thread pool. +/// You can only access the thread local value through the Deref impl +/// on the thread pool it was constructed on. It will panic otherwise +pub struct ThreadLocal { + locals: Vec>, + registry: Arc, +} + +unsafe impl Send for ThreadLocal {} +unsafe impl Sync for ThreadLocal {} + +impl ThreadLocal { + /// Creates a new thread local where the `initial` closure computes the + /// value this thread local should take for each thread in the thread pool. + #[inline] + pub fn new T>(mut initial: F) -> ThreadLocal { + let registry = Registry::current(); + ThreadLocal { + locals: (0..registry.num_threads()) + .map(|i| CacheAligned(initial(i))) + .collect(), + registry, + } + } + + /// Returns the thread-local value for each thread + #[inline] + pub fn into_inner(self) -> Vec { + self.locals.into_iter().map(|c| c.0).collect() + } + + fn current(&self) -> &T { + unsafe { + let worker_thread = WorkerThread::current(); + if worker_thread.is_null() + || &*(*worker_thread).registry as *const _ != &*self.registry as *const _ + { + panic!("ThreadLocal can only be used on the thread pool it was created on") + } + &self.locals[(*worker_thread).index].0 + } + } +} + +impl ThreadLocal> { + /// Joins the elements of all the thread locals into one Vec + pub fn join(self) -> Vec { + self.into_inner().into_iter().flat_map(|v| v).collect() + } +} + +impl fmt::Debug for ThreadLocal { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self.locals, f) + } +} + +impl Deref for ThreadLocal { + type Target = T; + + #[inline(always)] + fn deref(&self) -> &T { + self.current() + } +}