Skip to content

Commit

Permalink
Eio.Workpool
Browse files Browse the repository at this point in the history
  • Loading branch information
SGrondin committed Sep 2, 2023
1 parent 19c42eb commit 925f912
Show file tree
Hide file tree
Showing 5 changed files with 519 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib_eio/eio.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module Condition = Condition
module Stream = Stream
module Lazy = Lazy
module Pool = Pool
(* Workpool of worker domains; see workpool.mli for the interface. *)
module Workpool = Workpool
module Exn = Exn
module Resource = Resource
module Flow = Flow
Expand Down
3 changes: 3 additions & 0 deletions lib_eio/eio.mli
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ module Pool = Pool
(** Cancelling fibers. *)
module Cancel = Eio__core.Cancel

(** A high-level domain workpool: a fixed set of worker domains that run
submitted jobs, each domain running a bounded number of jobs at a time. *)
module Workpool = Workpool

(** Commonly used standard features. This module is intended to be [open]ed. *)
module Std = Std

Expand Down
140 changes: 140 additions & 0 deletions lib_eio/workpool.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
(* A unit of work: the function to run, paired with the resolver used to report
   its outcome ([Ok] result or [Error] exception). The result type ['a] is
   existentially hidden, so jobs with different result types can travel through
   the same [action Stream.t]. *)
type job = Pack : (unit -> 'a) * ('a, exn) Result.t Promise.u -> job

(* A message received by a worker.
   [Process] carries a job taken from the work queue.
   [Quit] is the value the [terminating] promise is resolved with; it tells the
   worker to stop taking jobs and carries what is needed to detect that every
   worker has shut down. *)
type action =
| Process of job
| Quit of {
(* Counter shared by all workers; each worker increments it once when its switch is released. *)
atomic: int Atomic.t;
(* The worker whose [Atomic.fetch_and_add] returns [target] is the last one to finish. *)
target: int;
(* Resolved by the last worker to finish. *)
all_done: unit Promise.u;
}

(* Terminology: a worker is one domain (thread).
Each worker runs up to m jobs at a time, and a workpool has n workers (domains). *)

(* A workpool. *)
type t = {
(* The switch given to [create]. The worker fibers are forked on it, as is the
   daemon that rejects queued jobs once [terminate] has been called. *)
sw: Switch.t;
(* The work queue *)
stream: action Stream.t;
(* Number of domains. Depending on settings, domains may run more than 1 job at a time. *)
domain_count: int;
(* True when [Workpool.terminate] has been called.
   Flipped with [Atomic.compare_and_set] so that only the first call initiates the shutdown. *)
is_terminating: bool Atomic.t;
(* Resolved when the workpool begins terminating.
   The value is the [Quit] action that every worker receives. *)
terminating: action Promise.t * action Promise.u;
(* Resolved when the workpool has terminated. *)
terminated: unit Promise.t * unit Promise.u;
}

(* Fails a job that will never be run: the job function is not called and its
   promise is resolved with a [Failure] error. *)
let reject = function
  | Pack (_, resolver) ->
    Promise.resolve_error resolver (Failure "Workpool.terminate called")

(* This function is the core of workpool.ml.
   Each worker recursively calls [loop ()] until the [terminating]
   promise is resolved. Workers pull one job at a time from the Stream.

   [limit] is the maximum number of jobs this worker runs at a time.
   [terminating] is the promise resolved with the [Quit] action on shutdown.
   [stream] is the work queue shared by all workers.

   The function returns once the worker's switch has finished, i.e. after a
   [Quit] has been received and all the jobs started by this worker are done. *)
let start_worker ~limit ~terminating stream =
  Switch.run @@ fun sw ->
  (* One unit of [capacity] is held from just before a job is taken until that job is done. *)
  let capacity = Semaphore.make limit in
  let run_job job w =
    Fiber.fork ~sw (fun () ->
        (* Any exception raised by the job is reported to the submitter
           rather than failing the worker's switch. *)
        Promise.resolve w
          (try Ok (job ()) with
           | exn -> Error exn);
        Semaphore.release capacity)
  in
  (* The main worker loop. *)
  let rec loop () =
    let actions =
      Fiber.n_any
        [
          (fun () -> Promise.await terminating);
          (fun () ->
            Semaphore.acquire capacity;
            Stream.take stream);
        ]
    in
    (* Both fibers may produce a value in the same round. The position of the
       [Quit] in the result list is not relied upon: the list is split into
       the shutdown request (if any) and the jobs received alongside it. *)
    let quit =
      List.find_map
        (function
          | Quit { atomic; target; all_done } -> Some (atomic, target, all_done)
          | Process _ -> None)
        actions
    in
    let jobs =
      List.filter_map
        (function
          | Process job -> Some job
          | Quit _ -> None)
        actions
    in
    match quit, jobs with
    | None, [ Pack (job, w) ] ->
      (* We start the job right away. This also gives a chance to other domains
         to start waiting on the Stream before the current thread blocks on [Stream.take] again. *)
      run_job job w;
      (loop [@tailcall]) ()
    | Some (atomic, target, all_done), jobs ->
      (* Shutting down: a job received in the same round is never started. *)
      List.iter reject jobs;
      (* Wait until the completion of all of this worker's jobs. *)
      Switch.on_release sw (fun () ->
          (* If we're the last worker terminating, resolve the promise. *)
          if Atomic.fetch_and_add atomic 1 = target then Promise.resolve all_done ())
    | None, ([] | _ :: _ :: _) ->
      (* Without a [Quit], exactly one job is expected: only one fiber takes from the stream. *)
      assert false
  in
  loop ()

(* Start a new domain and run a worker in it.
   The worker is run through [Domain_manager.run] from a fiber forked on [sw].

   [transient] workpools run as daemons to not hold the user's switch from completing.
   It's up to the user to hold the switch open (and thus, the workpool)
   by blocking on the jobs issued to the workpool.
   [Workpool.submit] and [Workpool.submit_exn] will block so this shouldn't be a problem.
   Still, the user can call [Workpool.create] with [~transient:false] to
   disable this behavior, in which case the user must call [Workpool.terminate]
   to release the switch. *)
let start_domain ~sw ~domain_mgr ~limit ~terminating ~transient stream =
  let run_worker () =
    Domain_manager.run domain_mgr (fun () -> start_worker ~limit ~terminating stream)
  in
  if transient
  then
    Fiber.fork_daemon ~sw (fun () ->
        run_worker ();
        `Stop_daemon)
  else Fiber.fork ~sw run_worker

(* Creates a workpool and immediately starts its [domain_count] workers on [sw].

   [domain_concurrency] is the number of jobs each worker may run at a time.
   [transient] (default: true) selects daemon fibers for the workers, see [start_domain].

   @raise Invalid_argument if [domain_count] or [domain_concurrency] is less than 1.
   Such a pool could never run a job, and with no worker [terminate] would
   wait forever for the [terminated] promise. *)
let create ~sw ~domain_count ~domain_concurrency ?(transient = true) domain_mgr =
  if domain_count < 1 then invalid_arg "Workpool.create: domain_count must be at least 1";
  if domain_concurrency < 1
  then invalid_arg "Workpool.create: domain_concurrency must be at least 1";
  let stream = Stream.create 0 in
  let instance =
    {
      sw;
      stream;
      domain_count;
      is_terminating = Atomic.make false;
      terminating = Promise.create ();
      terminated = Promise.create ();
    }
  in
  (* Workers only need the read side of the [terminating] promise. *)
  let terminating = fst instance.terminating in
  for _ = 1 to domain_count do
    start_domain ~sw ~domain_mgr ~limit:domain_concurrency ~terminating ~transient stream
  done;
  instance

(* Submits [f] from a new fiber forked on [sw] and returns a promise for its outcome.
   The forked fiber adds the job to the stream and then waits for its result;
   an [Error] outcome is re-raised inside that fiber by [Promise.await_exn]. *)
let submit_fork ~sw { stream; _ } f =
  let outcome, resolver = Promise.create () in
  let enqueue_and_wait () =
    Stream.add stream (Process (Pack (f, resolver)));
    Promise.await_exn outcome
  in
  Fiber.fork_promise ~sw enqueue_and_wait

(* Adds the job [f] to the work queue, then waits for its outcome:
   [Ok] with the value returned by [f], or [Error] with the exception. *)
let submit pool f =
  let outcome, resolver = Promise.create () in
  let action = Process (Pack (f, resolver)) in
  Stream.add pool.stream action;
  Promise.await outcome

(* Like [submit], but returns the job's value directly and raises the
   exception carried by an [Error] outcome. *)
let submit_exn instance f =
  Result.fold ~ok:(fun value -> value) ~error:raise (submit instance f)

(* Shuts the workpool down and waits until every worker has finished.
   Only the first call initiates the shutdown; any later call
   ([Workpool.terminate] called more than once) just waits for completion. *)
let terminate ({ terminating = _, begin_shutdown; terminated = finished, all_done; _ } as instance) =
  let first_call = Atomic.compare_and_set instance.is_terminating false true in
  if first_call
  then begin
    (* Instruct workers to shutdown *)
    Promise.resolve begin_shutdown
      (Quit { atomic = Atomic.make 1; target = instance.domain_count; all_done });
    (* Reject all present and future queued jobs.
       [drain] never returns: it runs as a daemon on the workpool's switch. *)
    let rec drain () =
      match Stream.take instance.stream with
      | Process job ->
        reject job;
        drain ()
      | Quit _ -> assert false
    in
    Fiber.fork_daemon ~sw:instance.sw drain
  end;
  (* Wait for all workers to have shutdown *)
  Promise.await finished

(* True once the [terminating] promise has been resolved. *)
let is_terminating pool = Promise.is_resolved (fst pool.terminating)

(* True once the [terminated] promise has been resolved. *)
let is_terminated pool =
  let finished, _ = pool.terminated in
  Promise.is_resolved finished
37 changes: 37 additions & 0 deletions lib_eio/workpool.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
(** A workpool: a set of worker domains that run submitted jobs. *)
type t

(** [create ~sw ~domain_count ~domain_concurrency domain_mgr] creates a new workpool
with [domain_count] worker domains, started through [domain_mgr].
[domain_concurrency] is the maximum number of jobs that each domain can run at a time.
[transient] (default: true). When true, the workpool will not block the [~sw] Switch from completing.
When false, you must call [terminate] to release the [~sw] Switch. *)
val create :
sw:Switch.t ->
domain_count:int ->
domain_concurrency:int ->
?transient:bool ->
_ Domain_manager.t ->
t

(** [submit t f] runs the job [f] on this workpool and waits for it to complete.
It is placed at the end of the queue.
Returns [Ok v] if [f] returned [v], or [Error exn] if [f] raised [exn].
A job that has not been started when [terminate] is called is not run;
its result is [Error (Failure _)]. *)
val submit : t -> (unit -> 'a) -> ('a, exn) result

(** Same as [submit] but raises if the job failed. *)
val submit_exn : t -> (unit -> 'a) -> 'a

(** Same as [submit] but returns immediately, without blocking.
The job is submitted from a fiber forked on [sw], and the returned promise
is resolved with the outcome of the job. *)
val submit_fork : sw:Switch.t -> t -> (unit -> 'a) -> ('a, exn) result Promise.t

(** Waits for all running jobs to complete, then returns.
No new jobs are started, even if they were already enqueued.
Calling [terminate] more than once is allowed: every call returns once the workpool has terminated.
To abort all running jobs instead of waiting for them, call [Switch.fail] on the Switch used to create this workpool. *)
val terminate : t -> unit

(** Returns true if the [terminate] function has been called on this workpool.
Also returns true if the workpool has fully terminated. *)
val is_terminating : t -> bool

(** Returns true if the [terminate] function has been called on this workpool AND
the workpool has fully terminated (all running jobs have completed). *)
val is_terminated : t -> bool
Loading

0 comments on commit 925f912

Please sign in to comment.