Skip to content

Commit

Permalink
Only work-steal in the main loop for join and scope
Browse files Browse the repository at this point in the history
  • Loading branch information
Zoxc committed Aug 27, 2023
1 parent f192a48 commit 7fdf1fd
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 89 deletions.
1 change: 1 addition & 0 deletions rayon-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ num_cpus = "1.2"
crossbeam-channel = "0.5.0"
crossbeam-deque = "0.8.1"
crossbeam-utils = "0.8.0"
smallvec = "1.11.0"

[dev-dependencies]
rand = "0.8"
Expand Down
2 changes: 1 addition & 1 deletion rayon-core/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ where
registry.inject_broadcast(job_refs);

// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
latch.wait(current_thread);
latch.wait(current_thread, None);
jobs.into_iter().map(|job| job.into_result()).collect()
}

Expand Down
26 changes: 14 additions & 12 deletions rayon-core/src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ pub(super) trait Job {
unsafe fn execute(this: *const ());
}

/// Opaque, comparable identity for a `JobRef`, derived from the job's
/// data pointer. Lets callers remember a pushed job and later recognise
/// it (e.g. when popping the local deque) without making `JobRef` itself
/// `Copy + Eq`.
#[derive(PartialEq, Eq, Hash, Copy, Clone)]
pub(super) struct JobRefId {
    // The job's data pointer cast to an integer; used purely as a key.
    pointer: usize,
}

/// Effectively a Job trait object. Each JobRef **must** be executed
/// exactly once, or else data may leak.
///
Expand Down Expand Up @@ -54,11 +59,11 @@ impl JobRef {
}
}

/// Returns an opaque handle that can be saved and compared,
/// without making `JobRef` itself `Copy + Eq`.
#[inline]
pub(super) fn id(&self) -> impl Eq {
(self.pointer, self.execute_fn)
pub(super) fn id(&self) -> JobRefId {
JobRefId {
pointer: self.pointer as usize,
}
}

#[inline]
Expand Down Expand Up @@ -102,10 +107,6 @@ where
JobRef::new(self)
}

pub(super) unsafe fn run_inline(self, stolen: bool) -> R {
self.func.into_inner().unwrap()(stolen)
}

pub(super) unsafe fn into_result(self) -> R {
self.result.into_inner().into_return_value()
}
Expand Down Expand Up @@ -136,15 +137,15 @@ where
/// (Probably `StackJob` should be refactored in a similar fashion.)
pub(super) struct HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
job: BODY,
tlv: Tlv,
}

impl<BODY> HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
Box::new(HeapJob { job, tlv })
Expand All @@ -168,12 +169,13 @@ where

impl<BODY> Job for HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
unsafe fn execute(this: *const ()) {
let pointer = this as usize;
let this = Box::from_raw(this as *mut Self);
tlv::set(this.tlv);
(this.job)();
(this.job)(JobRefId { pointer });
}
}

Expand Down
77 changes: 23 additions & 54 deletions rayon-core/src/join/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::job::StackJob;
use crate::latch::SpinLatch;
use crate::registry::{self, WorkerThread};
use crate::tlv::{self, Tlv};
use crate::registry;
use crate::tlv;
use crate::unwind;
use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use crate::FnContext;

Expand Down Expand Up @@ -135,68 +135,37 @@ where
// Create virtual wrapper for task b; this all has to be
// done here so that the stack frame can keep it all live
// long enough.
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
let job_b_started = AtomicBool::new(false);
let job_b = StackJob::new(
tlv,
|migrated| {
job_b_started.store(true, Ordering::Relaxed);
call_b(oper_b)(migrated)
},
SpinLatch::new(worker_thread),
);
let job_b_ref = job_b.as_job_ref();
let job_b_id = job_b_ref.id();
worker_thread.push(job_b_ref);

// Execute task a; hopefully b gets stolen in the meantime.
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
let result_a = match status_a {
Ok(v) => v,
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
};

// Now that task A has finished, try to pop job B from the
// local stack. It may already have been popped by job A; it
// may also have been stolen. There may also be some tasks
// pushed on top of it in the stack, and we will have to pop
// those off to get to it.
while !job_b.latch.probe() {
if let Some(job) = worker_thread.take_local_job() {
if job_b_id == job.id() {
// Found it! Let's run it.
//
// Note that this could panic, but it's ok if we unwind here.

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

let result_b = job_b.run_inline(injected);
return (result_a, result_b);
} else {
worker_thread.execute(job);
}
} else {
// Local deque is empty. Time to steal from other
// threads.
worker_thread.wait_until(&job_b.latch);
debug_assert!(job_b.latch.probe());
break;
}
}
// Wait for job B or execute it if it's in the local queue.
worker_thread.wait_for_jobs(
&job_b.latch,
|| job_b_started.load(Ordering::Relaxed),
|job| job.id() == job_b_id,
);

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

let result_a = match status_a {
Ok(v) => v,
Err(err) => unwind::resume_unwinding(err),
};

(result_a, job_b.into_result())
})
}

/// If job A panics, we still cannot return until we are sure that job
/// B is complete. This is because it may contain references into the
/// enclosing stack frame(s).
#[cold] // cold path
unsafe fn join_recover_from_panic(
    worker_thread: &WorkerThread,
    job_b_latch: &SpinLatch<'_>,
    err: Box<dyn Any + Send>,
    tlv: Tlv,
) -> ! {
    // Job B may still be executing (possibly on another thread) and may
    // borrow from our stack frame, so we must block on its latch before
    // we unwind past that frame.
    worker_thread.wait_until(job_b_latch);

    // Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
    tlv::set(tlv);

    // Now that B is done it is safe to propagate A's panic.
    unwind::resume_unwinding(err)
}
5 changes: 0 additions & 5 deletions rayon-core/src/latch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,6 @@ impl<'r> SpinLatch<'r> {
..SpinLatch::new(thread)
}
}

#[inline]
pub(super) fn probe(&self) -> bool {
self.core_latch.probe()
}
}

impl<'r> AsCoreLatch for SpinLatch<'r> {
Expand Down
65 changes: 58 additions & 7 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
ReleaseThreadHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, Yield,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use smallvec::SmallVec;
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
Expand Down Expand Up @@ -840,14 +841,58 @@ impl WorkerThread {
/// stealing tasks as necessary.
#[inline]
pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
    // Waiting here never work-steals: `steal == false`. Only the worker
    // main loop passes `steal == true` via `wait_or_steal_until`.
    self.wait_or_steal_until(latch, false)
}

/// Wait until `latch` is set, executing matching local jobs while waiting.
///
/// While `all_jobs_started()` returns false, pops jobs off this worker's
/// local deque: jobs for which `is_job` returns true are executed
/// immediately; any other popped job is buffered and pushed back after
/// the scan. Finally blocks in `wait_until` until the latch is set.
#[inline]
pub(super) unsafe fn wait_for_jobs<L: AsCoreLatch + ?Sized>(
    &self,
    latch: &L,
    mut all_jobs_started: impl FnMut() -> bool,
    mut is_job: impl FnMut(&JobRef) -> bool,
) {
    // Jobs we popped but were *not* looking for; usually few, so keep
    // them inline on the stack (spills to the heap past 8 entries).
    let mut jobs = SmallVec::<[JobRef; 8]>::new();

    // Make sure all jobs have started.
    while !all_jobs_started() {
        if let Some(job) = self.worker.pop() {
            if is_job(&job) {
                // Found a job, let's run it.
                self.execute(job);
            } else {
                jobs.push(job);
            }
        } else {
            // Local deque exhausted; the remaining jobs were stolen.
            break;
        }
    }

    // Restore the jobs that we weren't looking for.
    // Pushed in reverse pop order — presumably to restore the deque's
    // original ordering; NOTE(review): confirm against Worker semantics.
    for job in jobs.into_iter().rev() {
        self.worker.push(job);
    }

    // Wait for the jobs to finish.
    self.wait_until(latch);
    debug_assert!(latch.as_core_latch().probe());
}

/// Block until `latch` is set. `steal` selects whether the slow path is
/// allowed to work-steal from other threads while waiting; the fast path
/// (latch already set) returns immediately either way.
#[inline]
pub(super) unsafe fn wait_or_steal_until<L: AsCoreLatch + ?Sized>(
    &self,
    latch: &L,
    steal: bool,
) {
    // Fast path: nothing to wait for.
    let core_latch = latch.as_core_latch();
    if core_latch.probe() {
        return;
    }
    // Slow path is kept out-of-line (#[cold]) to keep this inlinable.
    self.wait_until_cold(core_latch, steal);
}

#[cold]
unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
unsafe fn wait_until_cold(&self, latch: &CoreLatch, steal: bool) {
// the code below should swallow all panics and hence never
// unwind; but if something does wrong, we want to abort,
// because otherwise other code in rayon may assume that the
Expand All @@ -857,10 +902,16 @@ impl WorkerThread {

let mut idle_state = self.registry.sleep.start_looking(self.index, latch);
while !latch.probe() {
if let Some(job) = self.find_work() {
self.registry.sleep.work_found(idle_state);
self.execute(job);
idle_state = self.registry.sleep.start_looking(self.index, latch);
if steal {
if let Some(job) = self.find_work() {
self.registry.sleep.work_found(idle_state);
self.execute(job);
idle_state = self.registry.sleep.start_looking(self.index, latch);
} else {
self.registry
.sleep
.no_work_found(&mut idle_state, latch, &self)
}
} else {
self.registry
.sleep
Expand Down Expand Up @@ -988,7 +1039,7 @@ unsafe fn main_loop(thread: ThreadBuilder) {
terminate_addr: my_terminate_latch.as_core_latch().addr(),
});
registry.acquire_thread();
worker_thread.wait_until(my_terminate_latch);
worker_thread.wait_or_steal_until(my_terminate_latch, true);

// Should not be any work left in our queue.
debug_assert!(worker_thread.take_local_job().is_none());
Expand Down
Loading

0 comments on commit 7fdf1fd

Please sign in to comment.