-
Notifications
You must be signed in to change notification settings - Fork 507
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a ThreadLocal type which allow you to hold a value per Rayon work…
…er thread
- Loading branch information
Showing
3 changed files
with
78 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
use registry::{Registry, WorkerThread}; | ||
use std::fmt; | ||
use std::ops::Deref; | ||
use std::sync::Arc; | ||
|
||
#[repr(align(64))] | ||
#[derive(Debug)] | ||
struct CacheAligned<T>(T); | ||
|
||
/// Holds thread-locals values for each thread in a thread pool. | ||
/// You can only access the thread local value through the Deref impl | ||
/// on the thread pool it was constructed on. It will panic otherwise | ||
pub struct ThreadLocal<T> { | ||
locals: Vec<CacheAligned<T>>, | ||
registry: Arc<Registry>, | ||
} | ||
|
||
unsafe impl<T> Send for ThreadLocal<T> {} | ||
unsafe impl<T> Sync for ThreadLocal<T> {} | ||
|
||
impl<T> ThreadLocal<T> { | ||
/// Creates a new thread local where the `initial` closure computes the | ||
/// value this thread local should take for each thread in the thread pool. | ||
#[inline] | ||
pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> ThreadLocal<T> { | ||
let registry = Registry::current(); | ||
ThreadLocal { | ||
locals: (0..registry.num_threads()) | ||
.map(|i| CacheAligned(initial(i))) | ||
.collect(), | ||
registry, | ||
} | ||
} | ||
|
||
/// Returns the thread-local value for each thread | ||
#[inline] | ||
pub fn into_inner(self) -> Vec<T> { | ||
self.locals.into_iter().map(|c| c.0).collect() | ||
} | ||
|
||
fn current(&self) -> &T { | ||
unsafe { | ||
let worker_thread = WorkerThread::current(); | ||
if worker_thread.is_null() | ||
|| &*(*worker_thread).registry as *const _ != &*self.registry as *const _ | ||
{ | ||
panic!("ThreadLocal can only be used on the thread pool it was created on") | ||
} | ||
&self.locals[(*worker_thread).index].0 | ||
} | ||
} | ||
} | ||
|
||
impl<T> ThreadLocal<Vec<T>> { | ||
/// Joins the elements of all the thread locals into one Vec | ||
pub fn join(self) -> Vec<T> { | ||
self.into_inner().into_iter().flat_map(|v| v).collect() | ||
} | ||
} | ||
|
||
impl<T: fmt::Debug> fmt::Debug for ThreadLocal<T> { | ||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
fmt::Debug::fmt(&self.locals, f) | ||
} | ||
} | ||
|
||
impl<T> Deref for ThreadLocal<T> { | ||
type Target = T; | ||
|
||
#[inline(always)] | ||
fn deref(&self) -> &T { | ||
self.current() | ||
} | ||
} |