Skip to content

Commit

Permalink
move to thread-local-storage 0.2 with get/set API
Browse files Browse the repository at this point in the history
  • Loading branch information
c-cube committed Aug 16, 2024
1 parent 3388098 commit 265d4f7
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 89 deletions.
1 change: 1 addition & 0 deletions dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
(depopts
(trace (>= 0.6))
thread-local-storage)
(conflicts (thread-local-storage (< 0.2)))
(tags
(thread pool domain futures fork-join)))

Expand Down
3 changes: 3 additions & 0 deletions moonpool.opam
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ depopts: [
"trace" {>= "0.6"}
"thread-local-storage"
]
conflicts: [
"thread-local-storage" {< "0.2"}
]
build: [
["dune" "subst"] {dev}
[
Expand Down
13 changes: 6 additions & 7 deletions src/core/fifo_pool.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,17 @@ let schedule_ (self : state) (task : task_full) : unit =
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type worker_state = { mutable cur_ls: Task_local_storage.t option }

let k_worker_state : worker_state option ref TLS.key =
TLS.new_key (fun () -> ref None)
let k_worker_state : worker_state TLS.t = TLS.create ()

let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let w = { cur_ls = None } in
TLS.get k_worker_state := Some w;
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
TLS.set k_worker_state w;
TLS.set Runner.For_runner_implementors.k_cur_runner runner;

let (AT_pair (before_task, after_task)) = around_task in

let on_suspend () =
match !(TLS.get k_worker_state) with
match TLS.get_opt k_worker_state with
| Some { cur_ls = Some ls; _ } -> ls
| _ -> assert false
in
Expand All @@ -55,7 +54,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
| T_start { ls; _ } | T_resume { ls; _ } -> ls
in
w.cur_ls <- Some ls;
TLS.get k_cur_storage := Some ls;
TLS.set k_cur_storage ls;
let _ctx = before_task runner in

(* run the task now, catching errors, handling effects *)
Expand All @@ -74,7 +73,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
on_exn e bt);
after_task runner _ctx;
w.cur_ls <- None;
TLS.get k_cur_storage := None
TLS.set k_cur_storage _dummy_ls
in

let main_loop () =
Expand Down
2 changes: 1 addition & 1 deletion src/core/runner.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ module For_runner_implementors = struct
let create ~size ~num_tasks ~shutdown ~run_async () : t =
{ size; num_tasks; shutdown; run_async }

let k_cur_runner : t option ref TLS.key = Types_.k_cur_runner
let k_cur_runner : t TLS.t = Types_.k_cur_runner
end

let dummy : t =
Expand Down
2 changes: 1 addition & 1 deletion src/core/runner.mli
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ module For_runner_implementors : sig
{b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x,
so that {!Fork_join} and other 5.x features work properly. *)

val k_cur_runner : t option ref Thread_local_storage_.key
val k_cur_runner : t Thread_local_storage_.t
(** Key that should be used by each runner to store itself in TLS
on every thread it controls, so that tasks running on these threads
can access the runner. This is necessary for {!get_current_runner}
Expand Down
6 changes: 4 additions & 2 deletions src/core/task_local_storage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ let key_count_ = A.make 0
type t = local_storage
type ls_value += Dummy

let dummy : t = ref [||]
let dummy : t = _dummy_ls

(** Resize array of TLS values *)
let[@inline never] resize_ (cur : ls_value array ref) n =
Expand Down Expand Up @@ -57,7 +57,9 @@ let new_key (type t) ~init () : t key =

let[@inline] get_cur_ () : ls_value array ref =
match get_current_storage () with
| Some r -> r
| Some r ->
assert (r != dummy);
r
| None -> failwith "Task local storage must be accessed from within a runner."

let[@inline] get (key : 'a key) : 'a =
Expand Down
12 changes: 5 additions & 7 deletions src/core/types_.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ type runner = {
num_tasks: unit -> int;
}

let k_cur_runner : runner option ref TLS.key = TLS.new_key (fun () -> ref None)

let k_cur_storage : local_storage option ref TLS.key =
TLS.new_key (fun () -> ref None)

let[@inline] get_current_runner () : _ option = !(TLS.get k_cur_runner)
let[@inline] get_current_storage () : _ option = !(TLS.get k_cur_storage)
let k_cur_runner : runner TLS.t = TLS.create ()
let k_cur_storage : local_storage TLS.t = TLS.create ()
let _dummy_ls : local_storage = ref [||]
let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner
let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage
let[@inline] create_local_storage () = ref [||]
15 changes: 7 additions & 8 deletions src/core/ws_pool.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@ let num_tasks_ (self : state) : int =
(** TLS, used by worker to store their specific state
and be able to retrieve it from tasks when we schedule new
sub-tasks. *)
let k_worker_state : worker_state option ref TLS.key =
TLS.new_key (fun () -> ref None)
let k_worker_state : worker_state TLS.t = TLS.create ()

let[@inline] find_current_worker_ () : worker_state option =
!(TLS.get k_worker_state)
TLS.get_opt k_worker_state

(** Try to wake up a waiter, if there's any. *)
let[@inline] try_wake_someone_ (self : state) : unit =
Expand Down Expand Up @@ -121,7 +120,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
in

w.cur_ls <- Some ls;
TLS.get k_cur_storage := Some ls;
TLS.set k_cur_storage ls;
let _ctx = before_task runner in

let[@inline] on_suspend () : _ ref =
Expand Down Expand Up @@ -166,7 +165,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)

after_task runner _ctx;
w.cur_ls <- None;
TLS.get k_cur_storage := None
TLS.set k_cur_storage _dummy_ls

let run_async_ (self : state) ~ls (f : task) : unit =
let w = find_current_worker_ () in
Expand Down Expand Up @@ -222,8 +221,8 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =

(** Main loop for a worker thread. *)
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
TLS.get k_worker_state := Some w;
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
TLS.set k_worker_state w;

let rec main () : unit =
worker_run_self_tasks_ self ~runner w;
Expand Down Expand Up @@ -358,7 +357,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let thread = Thread.self () in
let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.get k_cur_storage := None;
TLS.set k_cur_storage _dummy_ls;
(* set thread name *)
Option.iter
Expand Down
24 changes: 8 additions & 16 deletions src/private/thread_local_storage_.mli
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
(** Thread local storage *)

(* TODO: alias this to the library if present *)
type 'a t
(** A TLS slot for values of type ['a]. This allows the storage of a
single value of type ['a] per thread. *)

type 'a key
(** A TLS key for values of type ['a]. This allows the
storage of a single value of type ['a] per thread. *)
val create : unit -> 'a t

val new_key : (unit -> 'a) -> 'a key
(** Allocate a new, generative key.
When the key is used for the first time on a thread,
the function is called to produce it.
val get : 'a t -> 'a
(** @raise Failure if not present *)

This should only ever be called at toplevel to produce
constants, do not use it in a loop. *)

val get : 'a key -> 'a
(** Get the value for the current thread. *)

val set : 'a key -> 'a -> unit
(** Set the value for the current thread. *)
val get_opt : 'a t -> 'a option
val set : 'a t -> 'a -> unit
123 changes: 76 additions & 47 deletions src/private/thread_local_storage_.real.ml
Original file line number Diff line number Diff line change
@@ -1,82 +1,111 @@
(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *)

module A = Atomic_
(* vendored from https://github.com/c-cube/thread-local-storage *)

(* sanity check *)
let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())

type 'a key = {
index: int; (** Unique index for this key. *)
compute: unit -> 'a;
(** Initializer for values for this key. Called at most
once per thread. *)
}
type 'a t = int
(** Unique index for this TLS slot. *)

let tls_length index =
let ceil_pow_2_minus_1 (n : int) : int =
let n = n lor (n lsr 1) in
let n = n lor (n lsr 2) in
let n = n lor (n lsr 4) in
let n = n lor (n lsr 8) in
let n = n lor (n lsr 16) in
if Sys.int_size > 32 then
n lor (n lsr 32)
else
n
in
let size = ceil_pow_2_minus_1 (index + 1) in
assert (size > index);
size

(** Counter used to allocate new keys *)
let counter = A.make 0
let counter = Atomic.make 0

(** Value used to detect a TLS slot that was not initialized yet *)
let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter
(** Value used to detect a TLS slot that was not initialized yet.
Because [counter] is private and lives forever, no other
object the user can see will have the same address. *)
let sentinel_value_for_uninit_tls : Obj.t = Obj.repr counter

let new_key compute : _ key =
let index = A.fetch_and_add counter 1 in
{ index; compute }
external max_wosize : unit -> int = "caml_sys_const_max_wosize"

let max_word_size = max_wosize ()

let create () : _ t =
let index = Atomic.fetch_and_add counter 1 in
if tls_length index <= max_word_size then
index
else (
(* Some platforms have a small max word size. *)
ignore (Atomic.fetch_and_add counter (-1));
failwith "Thread_local_storage.create: out of TLS slots"
)

type thread_internal_state = {
_id: int; (** Thread ID (here for padding reasons) *)
mutable tls: Obj.t; (** Our data, stowed away in this unused field *)
_other: Obj.t;
(** Here to avoid lying to ocamlopt/flambda about the size of [Thread.t] *)
}
(** A partial representation of the internal type [Thread.t], allowing
us to access the second field (unused after the thread
has started) and stash TLS data in it. *)

let ceil_pow_2_minus_1 (n : int) : int =
let n = n lor (n lsr 1) in
let n = n lor (n lsr 2) in
let n = n lor (n lsr 4) in
let n = n lor (n lsr 8) in
let n = n lor (n lsr 16) in
if Sys.int_size > 32 then
n lor (n lsr 32)
let[@inline] get_raw index : Obj.t =
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
let tls = thread.tls in
if Obj.is_block tls && index < Array.length (Obj.obj tls : Obj.t array) then
Array.unsafe_get (Obj.obj tls : Obj.t array) index
else
sentinel_value_for_uninit_tls

let[@inline never] tls_error () =
failwith "Thread_local_storage.get: TLS entry not initialised"

let[@inline] get slot =
let v = get_raw slot in
if v != sentinel_value_for_uninit_tls then
Obj.obj v
else
tls_error ()

let[@inline] get_opt slot =
let v = get_raw slot in
if v != sentinel_value_for_uninit_tls then
Some (Obj.obj v)
else
n
None

(** Allocating and setting *)

(** Grow the array so that [index] is valid. *)
let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array =
let new_length = ceil_pow_2_minus_1 (index + 1) in
let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in
let grow (old : Obj.t array) (index : int) : Obj.t array =
let new_length = tls_length index in
let new_ = Array.make new_length sentinel_value_for_uninit_tls in
Array.blit old 0 new_ 0 (Array.length old);
new_

let[@inline] get_tls_ (index : int) : Obj.t array =
let get_tls_with_capacity index : Obj.t array =
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
let tls = thread.tls in
if Obj.is_int tls then (
let new_tls = grow_tls [||] index in
thread.tls <- Obj.magic new_tls;
let new_tls = grow [||] index in
thread.tls <- Obj.repr new_tls;
new_tls
) else (
let tls = (Obj.magic tls : Obj.t array) in
let tls = (Obj.obj tls : Obj.t array) in
if index < Array.length tls then
tls
else (
let new_tls = grow_tls tls index in
thread.tls <- Obj.magic new_tls;
let new_tls = grow tls index in
thread.tls <- Obj.repr new_tls;
new_tls
)
)

let get key =
let tls = get_tls_ key.index in
let value = Array.unsafe_get tls key.index in
if value != sentinel_value_for_uninit_tls_ () then
Obj.magic value
else (
let value = key.compute () in
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value));
value
)

let set key value =
let tls = get_tls_ key.index in
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value))
let[@inline] set slot value : unit =
let tls = get_tls_with_capacity slot in
Array.unsafe_set tls slot (Obj.repr (Sys.opaque_identity value))

0 comments on commit 265d4f7

Please sign in to comment.