diff --git a/dune-project b/dune-project index 234d03ad..2fbaf0d8 100644 --- a/dune-project +++ b/dune-project @@ -30,6 +30,7 @@ (depopts (trace (>= 0.6)) thread-local-storage) + (conflicts (thread-local-storage (< 0.2))) (tags (thread pool domain futures fork-join))) diff --git a/moonpool.opam b/moonpool.opam index dcd9500f..338cddc1 100644 --- a/moonpool.opam +++ b/moonpool.opam @@ -22,6 +22,9 @@ depopts: [ "trace" {>= "0.6"} "thread-local-storage" ] +conflicts: [ + "thread-local-storage" {< "0.2"} +] build: [ ["dune" "subst"] {dev} [ diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index cd65e7ca..a16d5b08 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -31,18 +31,17 @@ let schedule_ (self : state) (task : task_full) : unit = type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type worker_state = { mutable cur_ls: Task_local_storage.t option } -let k_worker_state : worker_state option ref TLS.key = - TLS.new_key (fun () -> ref None) +let k_worker_state : worker_state TLS.t = TLS.create () let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = let w = { cur_ls = None } in - TLS.get k_worker_state := Some w; - TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; + TLS.set k_worker_state w; + TLS.set Runner.For_runner_implementors.k_cur_runner runner; let (AT_pair (before_task, after_task)) = around_task in let on_suspend () = - match !(TLS.get k_worker_state) with + match TLS.get_opt k_worker_state with | Some { cur_ls = Some ls; _ } -> ls | _ -> assert false in @@ -55,7 +54,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = | T_start { ls; _ } | T_resume { ls; _ } -> ls in w.cur_ls <- Some ls; - TLS.get k_cur_storage := Some ls; + TLS.set k_cur_storage ls; let _ctx = before_task runner in (* run the task now, catching errors, handling effects *) @@ -74,7 +73,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = on_exn e bt); after_task runner _ctx; w.cur_ls <- None; - TLS.get k_cur_storage := None + TLS.set k_cur_storage _dummy_ls in let main_loop () = diff --git a/src/core/runner.ml b/src/core/runner.ml index 0fda1be3..0bf7895c 100644 --- a/src/core/runner.ml +++ b/src/core/runner.ml @@ -40,7 +40,7 @@ module For_runner_implementors = struct let create ~size ~num_tasks ~shutdown ~run_async () : t = { size; num_tasks; shutdown; run_async } - let k_cur_runner : t option ref TLS.key = Types_.k_cur_runner + let k_cur_runner : t TLS.t = Types_.k_cur_runner end let dummy : t = diff --git a/src/core/runner.mli b/src/core/runner.mli index 9a568b8c..f0b0d099 100644 --- a/src/core/runner.mli +++ b/src/core/runner.mli @@ -73,7 +73,7 @@ module For_runner_implementors : sig {b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, so that {!Fork_join} and other 5.x features work properly. *) - val k_cur_runner : t option ref Thread_local_storage_.key + val k_cur_runner : t Thread_local_storage_.t (** Key that should be used by each runner to store itself in TLS on every thread it controls, so that tasks running on these threads can access the runner. This is necessary for {!get_current_runner} diff --git a/src/core/task_local_storage.ml b/src/core/task_local_storage.ml index b62f4908..f4278d64 100644 --- a/src/core/task_local_storage.ml +++ b/src/core/task_local_storage.ml @@ -8,7 +8,7 @@ let key_count_ = A.make 0 type t = local_storage type ls_value += Dummy -let dummy : t = ref [||] +let dummy : t = _dummy_ls (** Resize array of TLS values *) let[@inline never] resize_ (cur : ls_value array ref) n = @@ -57,7 +57,9 @@ let new_key (type t) ~init () : t key = let[@inline] get_cur_ () : ls_value array ref = match get_current_storage () with - | Some r -> r + | Some r -> + assert (r != dummy); + r | None -> failwith "Task local storage must be accessed from within a runner." let[@inline] get (key : 'a key) : 'a = diff --git a/src/core/types_.ml b/src/core/types_.ml index f601d2be..08d2f09c 100644 --- a/src/core/types_.ml +++ b/src/core/types_.ml @@ -27,11 +27,9 @@ type runner = { num_tasks: unit -> int; } -let k_cur_runner : runner option ref TLS.key = TLS.new_key (fun () -> ref None) - -let k_cur_storage : local_storage option ref TLS.key = - TLS.new_key (fun () -> ref None) - -let[@inline] get_current_runner () : _ option = !(TLS.get k_cur_runner) -let[@inline] get_current_storage () : _ option = !(TLS.get k_cur_storage) +let k_cur_runner : runner TLS.t = TLS.create () +let k_cur_storage : local_storage TLS.t = TLS.create () +let _dummy_ls : local_storage = ref [||] +let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner +let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage let[@inline] create_local_storage () = ref [||] diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index cca913c8..de4b44cc 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -65,11 +65,10 @@ let num_tasks_ (self : state) : int = (** TLS, used by worker to store their specific state and be able to retrieve it from tasks when we schedule new sub-tasks. *) -let k_worker_state : worker_state option ref TLS.key = - TLS.new_key (fun () -> ref None) +let k_worker_state : worker_state TLS.t = TLS.create () let[@inline] find_current_worker_ () : worker_state option = - !(TLS.get k_worker_state) + TLS.get_opt k_worker_state (** Try to wake up a waiter, if there's any. *) let[@inline] try_wake_someone_ (self : state) : unit = @@ -121,7 +120,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) in w.cur_ls <- Some ls; - TLS.get k_cur_storage := Some ls; + TLS.set k_cur_storage ls; let _ctx = before_task runner in let[@inline] on_suspend () : _ ref = @@ -166,7 +165,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) after_task runner _ctx; w.cur_ls <- None; - TLS.get k_cur_storage := None + TLS.set k_cur_storage _dummy_ls let run_async_ (self : state) ~ls (f : task) : unit = let w = find_current_worker_ () in @@ -222,8 +221,8 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit = (** Main loop for a worker thread. *) let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = - TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; - TLS.get k_worker_state := Some w; + TLS.set Runner.For_runner_implementors.k_cur_runner runner; + TLS.set k_worker_state w; let rec main () : unit = worker_run_self_tasks_ self ~runner w; @@ -358,7 +357,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let thread = Thread.self () in let t_id = Thread.id thread in on_init_thread ~dom_id:dom_idx ~t_id (); - TLS.get k_cur_storage := None; + TLS.set k_cur_storage _dummy_ls; (* set thread name *) Option.iter diff --git a/src/private/thread_local_storage_.mli b/src/private/thread_local_storage_.mli index b7b50706..16aea3dd 100644 --- a/src/private/thread_local_storage_.mli +++ b/src/private/thread_local_storage_.mli @@ -1,21 +1,13 @@ (** Thread local storage *) -(* TODO: alias this to the library if present *) +type 'a t +(** A TLS slot for values of type ['a]. This allows the storage of a + single value of type ['a] per thread. *) -type 'a key -(** A TLS key for values of type ['a]. This allows the - storage of a single value of type ['a] per thread. *) +val create : unit -> 'a t -val new_key : (unit -> 'a) -> 'a key -(** Allocate a new, generative key. - When the key is used for the first time on a thread, - the function is called to produce it. +val get : 'a t -> 'a +(** @raise Failure if not present *) - This should only ever be called at toplevel to produce - constants, do not use it in a loop. *) - -val get : 'a key -> 'a -(** Get the value for the current thread. *) - -val set : 'a key -> 'a -> unit -(** Set the value for the current thread. *) +val get_opt : 'a t -> 'a option +val set : 'a t -> 'a -> unit diff --git a/src/private/thread_local_storage_.real.ml b/src/private/thread_local_storage_.real.ml index 70d7a558..09411a06 100644 --- a/src/private/thread_local_storage_.real.ml +++ b/src/private/thread_local_storage_.real.ml @@ -1,82 +1,111 @@ -(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *) - -module A = Atomic_ +(* vendored from https://github.com/c-cube/thread-local-storage *) (* sanity check *) let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ()) -type 'a key = { - index: int; (** Unique index for this key. *) - compute: unit -> 'a; - (** Initializer for values for this key. Called at most - once per thread. *) -} +type 'a t = int +(** Unique index for this TLS slot. *) + +let tls_length index = + let ceil_pow_2_minus_1 (n : int) : int = + let n = n lor (n lsr 1) in + let n = n lor (n lsr 2) in + let n = n lor (n lsr 4) in + let n = n lor (n lsr 8) in + let n = n lor (n lsr 16) in + if Sys.int_size > 32 then + n lor (n lsr 32) + else + n + in + let size = ceil_pow_2_minus_1 (index + 1) in + assert (size > index); + size (** Counter used to allocate new keys *) -let counter = A.make 0 +let counter = Atomic.make 0 -(** Value used to detect a TLS slot that was not initialized yet *) -let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter +(** Value used to detect a TLS slot that was not initialized yet. + Because [counter] is private and lives forever, no other + object the user can see will have the same address. *) +let sentinel_value_for_uninit_tls : Obj.t = Obj.repr counter -let new_key compute : _ key = - let index = A.fetch_and_add counter 1 in - { index; compute } +external max_wosize : unit -> int = "caml_sys_const_max_wosize" + +let max_word_size = max_wosize () + +let create () : _ t = + let index = Atomic.fetch_and_add counter 1 in + if tls_length index <= max_word_size then + index + else ( + (* Some platforms have a small max word size. *) + ignore (Atomic.fetch_and_add counter (-1)); + failwith "Thread_local_storage.create: out of TLS slots" + ) type thread_internal_state = { _id: int; (** Thread ID (here for padding reasons) *) mutable tls: Obj.t; (** Our data, stowed away in this unused field *) + _other: Obj.t; + (** Here to avoid lying to ocamlopt/flambda about the size of [Thread.t] *) } (** A partial representation of the internal type [Thread.t], allowing us to access the second field (unused after the thread has started) and stash TLS data in it. *) -let ceil_pow_2_minus_1 (n : int) : int = - let n = n lor (n lsr 1) in - let n = n lor (n lsr 2) in - let n = n lor (n lsr 4) in - let n = n lor (n lsr 8) in - let n = n lor (n lsr 16) in - if Sys.int_size > 32 then - n lor (n lsr 32) +let[@inline] get_raw index : Obj.t = + let thread : thread_internal_state = Obj.magic (Thread.self ()) in + let tls = thread.tls in + if Obj.is_block tls && index < Array.length (Obj.obj tls : Obj.t array) then + Array.unsafe_get (Obj.obj tls : Obj.t array) index + else + sentinel_value_for_uninit_tls + +let[@inline never] tls_error () = + failwith "Thread_local_storage.get: TLS entry not initialised" + +let[@inline] get slot = + let v = get_raw slot in + if v != sentinel_value_for_uninit_tls then + Obj.obj v + else + tls_error () + +let[@inline] get_opt slot = + let v = get_raw slot in + if v != sentinel_value_for_uninit_tls then + Some (Obj.obj v) else - n + None + +(** Allocating and setting *) (** Grow the array so that [index] is valid. *) -let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array = - let new_length = ceil_pow_2_minus_1 (index + 1) in - let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in +let grow (old : Obj.t array) (index : int) : Obj.t array = + let new_length = tls_length index in + let new_ = Array.make new_length sentinel_value_for_uninit_tls in Array.blit old 0 new_ 0 (Array.length old); new_ -let[@inline] get_tls_ (index : int) : Obj.t array = +let get_tls_with_capacity index : Obj.t array = let thread : thread_internal_state = Obj.magic (Thread.self ()) in let tls = thread.tls in if Obj.is_int tls then ( - let new_tls = grow_tls [||] index in - thread.tls <- Obj.magic new_tls; + let new_tls = grow [||] index in + thread.tls <- Obj.repr new_tls; new_tls ) else ( - let tls = (Obj.magic tls : Obj.t array) in + let tls = (Obj.obj tls : Obj.t array) in if index < Array.length tls then tls else ( - let new_tls = grow_tls tls index in - thread.tls <- Obj.magic new_tls; + let new_tls = grow tls index in + thread.tls <- Obj.repr new_tls; new_tls ) ) -let get key = - let tls = get_tls_ key.index in - let value = Array.unsafe_get tls key.index in - if value != sentinel_value_for_uninit_tls_ () then - Obj.magic value - else ( - let value = key.compute () in - Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)); - value - ) - -let set key value = - let tls = get_tls_ key.index in - Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)) +let[@inline] set slot value : unit = + let tls = get_tls_with_capacity slot in + Array.unsafe_set tls slot (Obj.repr (Sys.opaque_identity value))