Skip to content

Commit

Permalink
Move Tnode.task -> Task.t
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 11, 2024
1 parent 4c4dd2c commit d54b5e0
Show file tree
Hide file tree
Showing 12 changed files with 74 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Migrated to cudajit 0.5.
- Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments.
- Renaming: (virtual) `device` -> `stream`, `physical_device` -> `device`.
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; TODO: renamed `backend_utils.ml` to `c_syntax.ml`.
- TODO: Moved the multicore backend from a `device = stream` model to a single device model.
- TODO: Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`.
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.
Expand Down
8 changes: 4 additions & 4 deletions arrayjit/lib/backend_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
module Types = struct
type 'context routine = {
context : 'context;
schedule : Tnode.task;
schedule : Task.t;
bindings : Indexing.lowered_bindings;
name : string;
}
Expand Down Expand Up @@ -232,7 +232,7 @@ module type Lowered_no_device_backend = sig
merge_buffer:(buffer_ptr * Tnode.t) option ref ->
context ->
procedure ->
context * Indexing.lowered_bindings * Tnode.task * string
context * Indexing.lowered_bindings * Task.t * string

val name : string
val initialize : unit -> unit
Expand Down Expand Up @@ -273,10 +273,10 @@ module type Lowered_backend = sig

val is_in_context : Low_level.traced_array -> bool
val ctx_arrays : context -> ctx_array Map.M(Tnode).t
val link : context -> code -> context * Indexing.lowered_bindings * Tnode.task
val link : context -> code -> context * Indexing.lowered_bindings * Task.t

val link_batch :
context -> code_batch -> context * Indexing.lowered_bindings * Tnode.task option array
context -> code_batch -> context * Indexing.lowered_bindings * Task.t option array

val unsafe_cleanup : unit -> unit

Expand Down
29 changes: 10 additions & 19 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ module Multicore_backend (Backend : Backend_types.No_device_backend) : Backend_t
struct
module Domain = Domain [@warning "-3"]

type task_list = Tnode.task Utils.mutable_list [@@deriving sexp_of]
type task_list = Task.t Utils.mutable_list [@@deriving sexp_of]

module Mut = Stdlib.Mutex
module Queue = Saturn_lockfree.Single_prod_single_cons_queue

type task_queue = Tnode.task Queue.t
type task_queue = Task.t Queue.t

let sexp_of_task_queue q =
Sexp.(List [ Atom "task_queue_of_size"; Atom (Int.to_string @@ Queue.size q) ])
Expand Down Expand Up @@ -95,7 +95,7 @@ struct

let%track3_l_sexp schedule_task stream task =
assert (Domain.is_main_domain ());
[%log_result "schedule_task", Tnode.describe task, "stream", (stream.ordinal : int)];
[%log_result "schedule_task", Task.describe task, "stream", (stream.ordinal : int)];
let d = stream.state in
Option.iter d.stream_error ~f:(fun e ->
Exn.reraise e @@ name ^ " stream " ^ Int.to_string stream.ordinal);
Expand Down Expand Up @@ -137,7 +137,7 @@ struct
done;
state.is_ready <- false;
Mut.unlock state.mut
| Some task -> Tnode.run task
| Some task -> Task.run task
done
with e ->
state.stream_error <- Some e;
Expand All @@ -155,16 +155,6 @@ struct
allocated_buffer = None;
}

let%track3_l_sexp make_work stream (Tnode.Task { description; _ } as task) =
[%log_result "make_work", description, "stream", (stream.ordinal : int)];
let work () = schedule_task stream task in
Tnode.Task
{
context_lifetime = task;
description = "schedules {" ^ description ^ "} on stream " ^ Int.to_string stream.ordinal;
work;
}

type context = { stream : stream; ctx : Backend.context; expected_merge_node : Tnode.t option }
[@@deriving sexp_of]

Expand All @@ -186,14 +176,15 @@ struct

let compile = Backend.compile
let compile_batch = Backend.compile_batch
let get_stream_name s = "stream " ^ Int.to_string s.ordinal

let link { ctx; stream; expected_merge_node = _ } code =
let task = Backend.link ~merge_buffer:stream.merge_buffer ctx code in
{
task with
context =
{ ctx = task.context; stream; expected_merge_node = Backend.expected_merge_node code };
schedule = make_work stream task.schedule;
schedule = Task.enschedule ~schedule_task ~get_stream_name stream task.schedule;
}

let link_batch { ctx; stream; expected_merge_node } code_batch =
Expand All @@ -205,7 +196,7 @@ struct
{
task with
context = { ctx = task.context; stream; expected_merge_node = merge_nodes.(i) };
schedule = make_work stream task.schedule;
schedule = Task.enschedule ~schedule_task ~get_stream_name stream task.schedule;
})) )

let from_host (context : context) (tn : Tnode.t) =
Expand All @@ -226,7 +217,7 @@ struct
Ndarray.render_array ~indices h_arr]]]
in
schedule_task context.stream
(Tnode.Task
(Task.Task
{
context_lifetime = context;
description =
Expand Down Expand Up @@ -260,7 +251,7 @@ struct
Ndarray.render_array ~indices h_arr]]]
in
schedule_task context.stream
(Tnode.Task
(Task.Task
{
context_lifetime = context;
description =
Expand Down Expand Up @@ -315,7 +306,7 @@ struct
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ Int.to_string dev.ordinal ^ " src "
^ Int.to_string src.stream.ordinal
in
schedule_task dev (Tnode.Task { context_lifetime = (src, dst); description; work })
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
in
match (Backend.get_buffer tn dst.ctx, Backend.get_buffer tn src.ctx) with
| Some dst, Some _ ->
Expand Down
3 changes: 1 addition & 2 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,14 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
in
let%diagn_l_sexp work () : unit =
[%log_result name];
Backend_utils.check_merge_buffer ~merge_buffer ~code_node:code.lowered.merge_node;
Indexing.apply run_variadic ();
if Utils.debug_log_from_routines () then (
Utils.log_trace_tree (Stdio.In_channel.read_lines log_file_name);
Stdlib.Sys.remove log_file_name)
in
( context,
Indexing.lowered_bindings code.bindings run_variadic,
Tn.Task
Task.Task
{
(* In particular, keep code alive so it doesn't get unloaded. *)
context_lifetime = (context, code);
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
[%log "kernel launched"]
in
( context,
Tn.Task
Task.Task
{
context_lifetime = context;
description = "launches " ^ name ^ " on " ^ context.label;
Expand Down
4 changes: 2 additions & 2 deletions arrayjit/lib/cuda_backend.missing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ let ctx_arrays Unimplemented_ctx = Map.empty (module Tnode)
let link (Unimplemented_ctx : context) (code : code) =
let lowered_bindings = List.map ~f:(fun s -> (s, ref 0)) @@ Indexing.bound_symbols code in
let task =
Tnode.Task
Task.Task
{
context_lifetime = ();
description = "CUDA missing: install cudajit";
Expand All @@ -42,7 +42,7 @@ let link_batch (Unimplemented_ctx : context) (code_batch : code_batch) =
let task =
Array.map code_batch ~f:(fun _ ->
Some
(Tnode.Task
(Task.Task
{
context_lifetime = ();
description = "CUDA missing: install cudajit";
Expand Down
1 change: 1 addition & 0 deletions arrayjit/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
tnode
low_level
assignments
task
backend_types
backend_utils
cc_backend
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
in
( context,
Indexing.lowered_bindings code.bindings run_variadic,
Tn.Task
Task.Task
{
context_lifetime = context;
description = "executes " ^ code.name ^ " on " ^ context.label;
Expand Down
43 changes: 43 additions & 0 deletions arrayjit/lib/task.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
open Base
module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

(** A deferred unit of work, packaged together with a description and with a value of an
    arbitrary (existentially hidden) type. *)
type t =
| Task : {
(* A value of any type carried along with the task; it is opaque to [sexp_of]. Backends store
   here e.g. the context and compiled code. NOTE(review): presumably its purpose is to keep
   these values reachable for as long as the task is -- confirm. *)
context_lifetime : ('a[@sexp.opaque]);
(* Human-readable label; returned by [describe] and logged by [run]. *)
description : string;
(* The thunk that performs the task when called. *)
work : unit -> unit;
}
-> t
[@@deriving sexp_of]

(** [describe task] is the human-readable description stored in [task]. *)
let describe (Task { description; _ }) = description

(* [run task] logs the task's description (under the "run" tag) and then calls its [work]
   thunk. Whatever [work] raises propagates to the caller: there is no handler here. *)
let%diagn_l_sexp run (Task task) =
[%log_result "run", task.description];
task.work ()

(** [prepend ~work task] is a task with the same [context_lifetime] and [description] as [task],
    whose work first calls the given [work] and then the original task's work. *)
let prepend ~work (Task task) =
  let original_work = task.work in
  let combined () =
    work ();
    original_work ()
  in
  Task { task with work = combined }

(* [enschedule ~schedule_task ~get_stream_name stream task] wraps [task] in a new task whose work,
   when run, calls [schedule_task stream task] rather than running [task] directly.
   - [schedule_task]: called with [stream] and the original task each time the wrapper runs.
   - [get_stream_name]: used to render [stream] in the log entry and in the new description.
   The original task is stored as the wrapper's [context_lifetime]. *)
let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream (Task { description; _ } as task) =
[%log_result "enschedule", description, "on", get_stream_name stream];
(* Nothing is scheduled at wrapping time; scheduling happens only when the wrapper is run. *)
let work () = schedule_task stream task in
Task
{
context_lifetime = task;
description = "schedules {" ^ description ^ "} on " ^ get_stream_name stream;
work;
}
15 changes: 0 additions & 15 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,6 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

type task =
| Task : {
context_lifetime : ('a[@sexp.opaque]);
description : string;
work : unit -> unit;
}
-> task
[@@deriving sexp_of]

let describe (Task task) = task.description

let%diagn_l_sexp run (Task task) =
[%log_result "run", task.description];
task.work ()

type memory_type =
| Constant (** The tensor node does not change after initialization. *)
| Nonconstant (** One of: [Changed_on_devices], [Volatile]. *)
Expand Down
4 changes: 2 additions & 2 deletions arrayjit/lib/writing_a_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Currently, OCANNL integrates new backends via code in [Backends](backends.ml), s
type lowered_bindings = (static_symbol, int ref) List.Assoc.t (* in indexing.ml *)
type task =
| Task : { context_lifetime : 'a; description : string; work : unit -> unit; } -> task (* in tnode.ml *)
| Task : { context_lifetime : 'a; description : string; work : unit -> unit; } -> task (* now Task.t, in task.ml *)
type 'context routine = {
context : 'context;
Expand Down Expand Up @@ -190,7 +190,7 @@ module type Lowered_no_device_backend = sig
...
val link_compiled :
context -> procedure -> context * Indexing.lowered_bindings * Tnode.task * string
context -> procedure -> context * Indexing.lowered_bindings * Task.t * string
...
end
Expand Down
15 changes: 8 additions & 7 deletions lib/train.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module Nd = Arrayjit.Ndarray
module NTDSL = Operation.NTDSL
module Asgns = Arrayjit.Assignments
module Idx = Arrayjit.Indexing
module Task = Arrayjit.Task
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib
module BT = Arrayjit.Backend_types.Types
Expand Down Expand Up @@ -43,7 +44,7 @@ module IDX = struct
let find_exn = Idx.find_exn
end

let run jitted = Tn.run jitted.BT.schedule
let run jitted = Task.run jitted.BT.schedule

let is_param t =
match t with
Expand Down Expand Up @@ -284,10 +285,10 @@ let%track3_sexp sync_run ?looping (type context)
(module Backend : Backend_type with type context = context) (routine : Backend.routine) t =
all_host_to_device (module Backend) routine.context t;
(match looping with
| None -> Tn.run routine.schedule
| None -> Task.run routine.schedule
| Some then_ ->
let f () =
Tn.run routine.schedule;
Task.run routine.schedule;
then_ ()
in
sequential_loop ~f routine.bindings);
Expand Down Expand Up @@ -366,20 +367,20 @@ let%track3_sexp parallel_update (type context)
assert (
Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer
~dst:grad_merge.context ~src:ctxs.(from));
(Tn.run grad_merge.schedule : unit))
(Task.run grad_merge.schedule : unit))
in
let merge_loss ~src =
assert (
Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:loss_merge.context ~src);
Tn.run loss_merge.schedule
Task.run loss_merge.schedule
in
(* FIXME: missing backcopy. *)
let needed_on_host = ref @@ Set.empty (module Tn) in
let%track3_sexp sync (devices_to_sync : int) : unit =
Arrayjit.Utils.parallel_merge merge_grads devices_to_sync;
(* We need to wait, because copying happens on other devices. *)
Array.iteri ctxs ~f:(fun i src -> if i <> 0 then Backend.(await @@ get_ctx_stream src));
Tn.run sgd_update.schedule;
Task.run sgd_update.schedule;
Array.iteri ctxs ~f:(fun i src -> if i <> 0 then merge_loss ~src);
Set.iter !needed_on_host ~f:(fun p -> assert (Backend.to_host sgd_update.context p));
Backend.(await @@ get_ctx_stream sgd_update.context);
Expand All @@ -394,7 +395,7 @@ let%track3_sexp parallel_update (type context)
post_sync ~num_synced_devices:devices_to_sync
in
let lowered_bindings = [%debug_notrace Array.map grad_updates ~f:(fun upd -> upd.bindings)] in
let fs = [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Tn.run upd.schedule)] in
let fs = [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Task.run upd.schedule)] in
fun () -> round_robin fs lowered_bindings sgd_update.bindings ~sync

let get_all_suggested_streams ?(max_num_streams : int option) (type device stream)
Expand Down

0 comments on commit d54b5e0

Please sign in to comment.