Skip to content

Commit

Permalink
Improved heuristic fallback to update-on-host
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jun 3, 2023
1 parent bb893e1 commit c2a2073
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 40 deletions.
19 changes: 5 additions & 14 deletions bin/moons_benchmark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ let classify_moons ~random_seed ~on_device executor ~opti_level ~inlining_cutoff
result = [%sexp_of: float * float] (!min_loss, !loss);
}
in
(*
let points = SDSL.value_2d_points ~xdim:0 ~ydim:1 moons_flat in
let classes = SDSL.value_1d_points ~xdim:0 moons_classes in
let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
Expand Down Expand Up @@ -337,19 +336,19 @@ let classify_moons ~random_seed ~on_device executor ~opti_level ~inlining_cutoff
(* Stdio.printf "\nProcess memory delta: %d\n%!"
(train_mem.process_physical_memory - init_mem.process_physical_memory); *)
Exec_as_gccjit.optimization_level := 3;
Stdio.printf "\n%!";*)
Stdio.printf "\n%!";
result

let benchmark_executor = SDSL.Gccjit

let () =
let () =
Node.fixed_state_for_init := Some 14;
ignore
@@ classify_moons ~random_seed:0 ~on_device:true benchmark_executor ~opti_level:3 ~inlining_cutoff:3
@@ classify_moons ~random_seed:3 ~on_device:true benchmark_executor ~opti_level:3 ~inlining_cutoff:3
~num_parallel_tasks:20 ~per_refresh:100 CDSL.single ()

let benchmarks =
List.concat_map [ 0; 3; 5 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 0; 3; 5 *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ 1; 2; 4; 8; 10; 20 ] ~f:(fun num_parallel_tasks ->
List.concat_map [ 1; 10; 100 ] ~f:(fun per_refresh ->
[
Expand All @@ -360,15 +359,7 @@ let benchmarks =
let time_of = function PrintBox_utils.Benchmark { time_in_sec; _ } -> time_in_sec

let nth_best nth bench =
let results =
[
bench ~random_seed:0 ();
bench ~random_seed:1 ();
bench ~random_seed:2 ();
bench ~random_seed:3 ();
bench ~random_seed:4 ();
]
in
let results = List.init 5 ~f:(fun random_seed -> bench ~random_seed ()) in
let sorted = List.sort results ~compare:(fun r1 r2 -> Float.compare (time_of r1) (time_of r2)) in
List.nth_exn sorted (nth - 1)

Expand Down
15 changes: 15 additions & 0 deletions lib/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ let to_low_level (code : t) : unit low_level =
let rec loop code =
match code with
| Accum_binop { zero_out; accum; op; lhs; rhs1; rhs2; projections } ->
let lhs_n = NodeUI.get lhs.id in
(match (accum, op) with
| Add, _ -> lhs_n.value_distributes_over_sum <- true
| Arg2, Mul ->
let rhs1_n = NodeUI.get rhs1.id in
let rhs2_n = NodeUI.get rhs2.id in
lhs_n.value_distributes_over_sum <-
(rhs1_n.value_distributes_over_sum && not rhs2_n.value_distributes_over_sum)
|| (rhs2_n.value_distributes_over_sum && not rhs1_n.value_distributes_over_sum)
| Arg2, Add ->
let rhs1_n = NodeUI.get rhs1.id in
let rhs2_n = NodeUI.get rhs2.id in
lhs_n.value_distributes_over_sum <-
rhs1_n.value_distributes_over_sum || rhs2_n.value_distributes_over_sum
| _ -> lhs_n.value_distributes_over_sum <- false);
let projections = projections () in
let lhs_idx = Shape.(derive_index projections.product_iterators projections.project_lhs) in
let rhs1_idx = Shape.(derive_index projections.product_iterators projections.project_rhs1) in
Expand Down
33 changes: 7 additions & 26 deletions lib/exec_as_gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -376,40 +376,21 @@ let jit_code ~name ~env ~task_id ({ ctx; func; _ } as state) initial_block (body
Block.assign !current_block device_lhs rhs
| Set (ptr, idcs, value) -> (
let host_idcs = lookup ~on_host:true env idcs in
let n = NodeUI.get ptr.id in
let distributive = NodeUI.equal_data_kind ptr.field Grad || n.value_distributes_over_sum in
match get_tensor state ~dependencies:value ~jit_code:loop_proc ~host_idcs ptr with
| tensor ->
let value = loop_float ~name ~env ~num_typ:tensor.num_typ ~is_double:tensor.is_double value in
let idcs = lookup ~on_host:false env idcs in
let device_offset = jit_array_offset ctx ~idcs ~dims:tensor.device_dims in
let device_lhs = LValue.access_array (get_ptr tensor) device_offset in
Block.assign !current_block device_lhs value
| exception Unknown_synchronization -> (
| exception Unknown_synchronization when distributive ->
(* Cache the tensor with an Update_on_host synchronization, then recompile as an update. *)
let cache_sync () =
ignore
@@ get_tensor state ~dependencies:value ~force_sync:Update_on_host ~jit_code:loop_proc ~host_idcs
ptr
in
match value with
| Binop (Add, c1, c2) ->
cache_sync ();
loop ~name
@@ Lines
[|
Set (ptr, idcs, Constant 0.0);
Set (ptr, idcs, Binop (Add, Get (ptr, idcs), c1));
Set (ptr, idcs, Binop (Add, Get (ptr, idcs), c2));
|]
| Binop (Mul, c1, c2) ->
cache_sync ();
loop ~name
@@ Lines
[|
Set (ptr, idcs, Constant 1.0);
Set (ptr, idcs, Binop (Mul, Get (ptr, idcs), c1));
Set (ptr, idcs, Binop (Mul, Get (ptr, idcs), c2));
|]
| _ -> raise Unknown_synchronization))
ignore
@@ get_tensor state ~dependencies:value ~force_sync:Update_on_host ~jit_code:loop_proc ~host_idcs
ptr;
loop ~name @@ Set (ptr, idcs, Binop (Add, Get (ptr, idcs), value)))
| Set_local (id, value) ->
let lhs, num_typ, is_double = Map.find_exn !locals id in
let value = loop_float ~name ~env ~num_typ ~is_double value in
Expand Down
3 changes: 3 additions & 0 deletions lib/nodeUI.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type t = {
mutable value_never_device_only : bool;
mutable grad_never_virtual : bool;
mutable grad_never_device_only : bool;
mutable value_distributes_over_sum : bool;
(** [value_distributes_over_sum] is a heuristic marker for deciding synchronization strategies. *)
literal : bool;
(** To avoid confusion, try to maintain the following for a literal:
- empty [children],
Expand Down Expand Up @@ -144,6 +146,7 @@ let create ~(value_prec : prec) ?(grad_prec : prec option) ?(literal = false) ~n
value_never_device_only = false;
grad_never_virtual = false;
grad_never_device_only = false;
value_distributes_over_sum = false;
literal;
backend_info = "";
}
Expand Down

0 comments on commit c2a2073

Please sign in to comment.