Skip to content

Commit

Permalink
Add a ThreadLocal type which allow you to hold a value per Rayon work…
Browse files Browse the repository at this point in the history
…er thread
  • Loading branch information
Zoxc committed May 4, 2018
1 parent ba13591 commit 7874a15
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 2 deletions.
2 changes: 2 additions & 0 deletions rayon-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mod scope;
mod sleep;
mod spawn;
mod test;
mod thread_local;
mod thread_pool;
mod unwind;
mod util;
Expand All @@ -65,6 +66,7 @@ pub use thread_pool::current_thread_has_pending_tasks;
pub use join::{join, join_context};
pub use scope::{scope, Scope};
pub use spawn::spawn;
pub use thread_local::ThreadLocal;

/// Returns the number of threads in the current registry. If this
/// code is executing within a Rayon thread-pool, then this will be
Expand Down
4 changes: 2 additions & 2 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,15 +456,15 @@ pub struct WorkerThread {
/// the "worker" half of our local deque
worker: Deque<JobRef>,

index: usize,
pub(crate) index: usize,

/// are these workers configured to steal breadth-first or not?
breadth_first: bool,

/// A weak random number generator.
rng: UnsafeCell<rand::XorShiftRng>,

registry: Arc<Registry>,
pub(crate) registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
Expand Down
74 changes: 74 additions & 0 deletions rayon-core/src/thread_local.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use registry::{Registry, WorkerThread};
use std::fmt;
use std::ops::Deref;
use std::sync::Arc;

#[repr(align(64))]
#[derive(Debug)]
struct CacheAligned<T>(T);

/// Holds thread-locals values for each thread in a thread pool.
/// You can only access the thread local value through the Deref impl
/// on the thread pool it was constructed on. It will panic otherwise
pub struct ThreadLocal<T> {
locals: Vec<CacheAligned<T>>,
registry: Arc<Registry>,
}

unsafe impl<T> Send for ThreadLocal<T> {}
unsafe impl<T> Sync for ThreadLocal<T> {}

impl<T> ThreadLocal<T> {
/// Creates a new thread local where the `initial` closure computes the
/// value this thread local should take for each thread in the thread pool.
#[inline]
pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> ThreadLocal<T> {
let registry = Registry::current();
ThreadLocal {
locals: (0..registry.num_threads())
.map(|i| CacheAligned(initial(i)))
.collect(),
registry,
}
}

/// Returns the thread-local value for each thread
#[inline]
pub fn into_inner(self) -> Vec<T> {
self.locals.into_iter().map(|c| c.0).collect()
}

fn current(&self) -> &T {
unsafe {
let worker_thread = WorkerThread::current();
if worker_thread.is_null()
|| &*(*worker_thread).registry as *const _ != &*self.registry as *const _
{
panic!("ThreadLocal can only be used on the thread pool it was created on")
}
&self.locals[(*worker_thread).index].0
}
}
}

impl<T> ThreadLocal<Vec<T>> {
/// Joins the elements of all the thread locals into one Vec
pub fn join(self) -> Vec<T> {
self.into_inner().into_iter().flat_map(|v| v).collect()
}
}

impl<T: fmt::Debug> fmt::Debug for ThreadLocal<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.locals, f)
}
}

impl<T> Deref for ThreadLocal<T> {
type Target = T;

#[inline(always)]
fn deref(&self) -> &T {
self.current()
}
}

0 comments on commit 7874a15

Please sign in to comment.