Skip to content

Commit

Permalink
ws pool: random stealing; rework main state machine
Browse files Browse the repository at this point in the history
In the state machine, after waiting we first check the main queue; if it is
still empty, we go directly to stealing.
  • Loading branch information
c-cube committed Oct 27, 2023
1 parent aa7906e commit aba0d84
Showing 1 changed file with 51 additions and 56 deletions.
107 changes: 51 additions & 56 deletions src/ws_pool.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ let ( let@ ) = ( @@ )
type worker_state = {
mutable thread: Thread.t;
q: task WSQ.t; (** Work stealing queue *)
mutable work_steal_offset: int; (** Current offset for work stealing *)
rng: Random.State.t; (** Random state used to draw the index of the first worker probed when stealing *)
}
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
Expand Down Expand Up @@ -111,39 +111,26 @@ let[@inline] wait_ (self : state) : unit =
self.n_waiting <- self.n_waiting - 1;
if self.n_waiting = 0 then self.n_waiting_nonzero <- false

(** Try to steal a task from the worker [w] *)
exception Got_task of task

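(** Try to steal a task from another worker: starting at a random index, probe each worker other than [w] at most once and return the first task stolen, or [None] if every probe comes back empty. *)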
let try_to_steal_work_once_ (self : state) (w : worker_state) : task option =
w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers;

(* if we're pointing to [w], skip to the next worker as
it's useless to steal from oneself *)
if Array.unsafe_get self.workers w.work_steal_offset == w then
w.work_steal_offset <-
(w.work_steal_offset + 1) mod Array.length self.workers;

let w' = Array.unsafe_get self.workers w.work_steal_offset in
WSQ.steal w'.q

(** Try to steal work from several other workers. *)
let try_to_steal_work_loop (self : state) ~runner w : bool =
if size_ self = 1 then
(* no stealing for single thread pool *)
false
else (
let has_stolen = ref false in
let n_retries_left = ref (size_ self - 1) in

while !n_retries_left > 0 do
match try_to_steal_work_once_ self w with
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner task;
has_stolen := true;
n_retries_left := 0
| None -> decr n_retries_left
let init = Random.State.int w.rng (Array.length self.workers) in

try
for i = 0 to Array.length self.workers - 1 do
let w' =
Array.unsafe_get self.workers ((i + init) mod Array.length self.workers)
in

if w != w' then (
match WSQ.steal w'.q with
| Some t -> raise_notrace (Got_task t)
| None -> ()
)
done;
!has_stolen
)
None
with Got_task t -> Some t

(** Worker runs tasks from its queue until none remains *)
let worker_run_self_tasks_ (self : state) ~runner w : unit =
Expand All @@ -160,29 +147,41 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.get k_worker_state := Some w;

let main_loop () : unit =
let continue = ref true in
while !continue && A.get self.active do
let rec main () : unit =
if A.get self.active then (
worker_run_self_tasks_ self ~runner w;

let did_steal = try_to_steal_work_loop self ~runner w in
if not did_steal then (
Mutex.lock self.mutex;
match Queue.pop self.main_q with
| task ->
Mutex.unlock self.mutex;
run_task_now_ self ~runner task
| exception Queue.Empty ->
if A.get self.active then wait_ self;
Mutex.unlock self.mutex
)
done;
assert (WSQ.size w.q = 0)
try_steal ()
)
and run_task task : unit =
run_task_now_ self ~runner task;
main ()
and try_steal () =
if A.get self.active then (
match try_to_steal_work_once_ self w with
| Some task -> run_task task
| None -> wait ()
)
and wait () =
Mutex.lock self.mutex;
match Queue.pop self.main_q with
| task ->
Mutex.unlock self.mutex;
run_task task
| exception Queue.Empty ->
(* wait here *)
if A.get self.active then wait_ self;

(* see if a task became available *)
let task = try Some (Queue.pop self.main_q) with Queue.Empty -> None in
Mutex.unlock self.mutex;

(match task with
| Some t -> run_task t
| None -> try_steal ())
in

(* handle domain-local await *)
Dla_.using ~prepare_for_await:Suspend_.prepare_for_await
~while_running:main_loop
Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main

let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()

Expand Down Expand Up @@ -226,11 +225,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let workers : worker_state array =
let dummy = Thread.self () in
Array.init num_threads (fun i ->
{
thread = dummy;
q = WSQ.create ();
work_steal_offset = (i + 1) mod num_threads;
})
{ thread = dummy; q = WSQ.create (); rng = Random.State.make [| i |] })
in
let pool =
Expand Down

0 comments on commit aba0d84

Please sign in to comment.