Skip to content

Commit

Permalink
More fine-grained refactoring of backend APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 20, 2024
1 parent 3cb9936 commit 1416586
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 101 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
- Got rid of `unsafe_cleanup`.
- Got rid of `subordinal`.
- Removed dependency on `core`, broke up dependency on `ppx_jane`.
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.
- TODO: Built per-tensor-node stream-to-stream synchronization into copying functions, removed obsolete blocking synchronizations.

### Fixed

Expand Down
176 changes: 80 additions & 96 deletions arrayjit/lib/backend_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ module Types = struct
}
[@@deriving sexp_of]

(** For now, we only configure a backend with regard to how many streams it should suggest using
(where applicable). *)
type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
[@@deriving equal, sexp, variants]

Expand All @@ -39,12 +41,10 @@ module Types = struct
[@@deriving sexp_of]
end

module type Backend_common = sig
type code [@@deriving sexp_of]
type code_batch [@@deriving sexp_of]
(** Parts shared by both assignments-level and lowered-level backend interfaces. *)
module type Backend_any_common = sig
type buffer_ptr [@@deriving sexp_of]
type context [@@deriving sexp_of]
type routine = context Types.routine [@@deriving sexp_of]
type stream

type init_info
Expand All @@ -67,9 +67,15 @@ module type Backend_common = sig
(** Finalizes (just) the context. *)

val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
end

val get_used_memory : unit -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
(** Parts shared by assignments-level backend interfaces. *)
module type Backend_common = sig
include Backend_any_common

type routine = context Types.routine [@@deriving sexp_of]
type code [@@deriving sexp_of]
type code_batch [@@deriving sexp_of]

val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
(** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
Expand All @@ -89,6 +95,7 @@ module type Backend_common = sig
[occupancy] returns true are included. *)
end

(** An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
module type No_device_backend = sig
include Backend_common with type init_info := string and type stream := unit

Expand All @@ -104,23 +111,21 @@ module type No_device_backend = sig
downstream of all the returned routines (in particular, the routines' contexts are not
independent). *)

val get_used_memory : unit -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
val get_buffer : Tnode.t -> context -> buffer_ptr option
end

module type Backend = sig
(** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
and devices. *)
module type Backend_device_common = sig
type stream [@@deriving sexp_of]

include Backend_common with type init_info := stream and type stream := stream

val link : context -> code -> routine
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)

val link_batch : context -> code_batch -> context * routine option array
(** Returns the routines for the procedures included in the code batch. The returned context is
downstream of all the returned routines. *)
include Backend_any_common with type init_info := stream and type stream := stream

type event
(** An event tracks if a stream finished computing past a particular point in its schedule. These
Expand All @@ -147,6 +152,51 @@ module type Backend = sig
called internally when necessary. But there is one exception, see {!device_to_device} when
[into_merge_buffer=Streaming]. *)

type device

val get_used_memory : device -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

val await : stream -> unit
(** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)

val all_work : stream -> event
(** Returns the event indicating if any currently running or scheduled computations on the stream
have completed. *)

val is_idle : stream -> bool
(** Whether the stream is currently waiting for work. *)

val get_device : ordinal:int -> device
val num_devices : unit -> int

val suggested_num_streams : device -> int
(** The optimal number of streams for the given device to follow the {!Types.config} strategy
passed to {!No_device_backend.initialize}. *)

val new_stream : device -> stream
val get_ctx_stream : context -> stream
val get_stream_device : stream -> device
val to_ordinal : device -> int
val get_name : stream -> string
end

module type Backend = sig
include Backend_device_common

include
Backend_common
with type context := context
and type init_info := stream
and type stream := stream

val link : context -> code -> routine
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)

val link_batch : context -> code_batch -> context * routine option array
(** Returns the routines for the procedures included in the code batch. The returned context is
downstream of all the returned routines. *)

val from_host : context -> Tnode.t -> bool
(** If the tensor node is both hosted and in-context, schedules a copy from host to context and
returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
Expand Down Expand Up @@ -175,47 +225,16 @@ module type Backend = sig
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
buffer but before scheduling work on [src] that modifies [tn], execute
[will_wait_for src (all_work (get_ctx_stream dst))]. *)

type device

val get_used_memory : device -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

val await : stream -> unit
(** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)

val all_work : stream -> event
(** Returns the event indicating if any currently running or scheduled computations on the stream
have completed. *)

val is_idle : stream -> bool
(** Whether the stream is currently waiting for work. *)

val get_device : ordinal:int -> device
val num_devices : unit -> int

val suggested_num_streams : device -> int
(** The optimal number of streams for the given device to follow the {!Types.config} strategy
passed to {!No_device_backend.initialize}. *)

val new_stream : device -> stream
val get_ctx_stream : context -> stream
val get_stream_device : stream -> device
val to_ordinal : device -> int
val get_name : stream -> string
end

(** Parts shared by lowered-level backends excluding what's already in {!Backend_any_common}. *)
module type Lowered_backend_common = sig
type context [@@deriving sexp_of]
type ctx_array [@@deriving sexp_of]
type ctx_arrays [@@deriving sexp_of]
type buffer_ptr [@@deriving sexp_of]
type config
type init_info
type stream
type buffer_ptr

val buffer_ptr : ctx_array -> buffer_ptr
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
val ctx_arrays : context -> ctx_arrays
val get_array : ctx_arrays -> Tnode.t -> ctx_array option

Expand All @@ -224,20 +243,18 @@ module type Lowered_backend_common = sig
Should return false for nodes that are virtual, local, or which the backend prefers to access
directly from the host. *)

val initialize : config -> unit
val is_initialized : unit -> bool
val init : init_info -> context
val finalize : context -> unit
val name : string
end

(** Lowered-level stream agnostic backend interface: implementation-facing API for CPU backends. *)
module type Lowered_no_device_backend = sig
include Lowered_backend_common

include
Lowered_backend_common
with type stream := unit
and type config := unit
Backend_any_common
with type context := context
and type stream := unit
and type init_info := string
and type buffer_ptr := buffer_ptr

type procedure [@@deriving sexp_of]

Expand Down Expand Up @@ -266,27 +283,15 @@ module type Lowered_no_device_backend = sig
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
end

(** Lowered-level backend interface: implementation-facing API for device-based (typically GPU)
backends. *)
module type Lowered_backend = sig
type stream [@@deriving sexp_of]

include
Lowered_backend_common
with type config := Types.config
and type stream := stream
and type init_info := stream

include Lowered_backend_common
include Backend_device_common with type context := context and type buffer_ptr := buffer_ptr

type code [@@deriving sexp_of]
type code_batch [@@deriving sexp_of]
type event

val sync : event -> unit
val is_done : event -> bool
val work_for : context -> Tnode.t -> event option
val will_wait_for : context -> event -> unit

open Types

val sexp_of_context : context -> Sexplib.Sexp.t
val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> code

val compile_batch :
Expand All @@ -301,34 +306,13 @@ module type Lowered_backend = sig
context -> code_batch -> context * Indexing.lowered_bindings * Task.t option array

val from_host : context -> Tnode.t -> bool
(** If the array is both hosted and in-context, copies from host to context. *)

val to_host : context -> Tnode.t -> bool
(** If the array is both hosted and in-context, copies from context to host. *)

val device_to_device :
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
(** See {!Backend.device_to_device}. *)

type device

val get_used_memory : device -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

val await : stream -> unit
val is_idle : stream -> bool
val all_work : stream -> event
Tnode.t -> into_merge_buffer:Types.merge_buffer_use -> dst:context -> src:context -> bool

val scheduled_merge_node : stream -> Tnode.t option
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
right after [await stream]. *)

val num_devices : unit -> int
val suggested_num_streams : device -> int
val get_device : ordinal:int -> device
val get_stream_device : stream -> device
val new_stream : device -> stream
val get_ctx_stream : context -> stream
val get_name : stream -> string
val to_ordinal : device -> int
end
4 changes: 3 additions & 1 deletion arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ let check_merge_buffer ~scheduled_node ~code_node =
("Merge buffer mismatch, on stream: " ^ name scheduled_node ^ ", expected by code: "
^ name code_node)

(* module *)

module Multicore_backend (Backend : Backend_types.No_device_backend) : Backend_types.Backend =
struct
module Domain = Domain [@warning "-3"]
Expand Down Expand Up @@ -581,7 +583,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back

let initialize config =
global_config := config;
initialize ()
initialize config

type nonrec routine = context routine [@@deriving sexp_of]

Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst

let is_initialized, initialize =
let initialized = ref false in
((fun () -> !initialized), fun () -> initialized := true)
((fun () -> !initialized), fun _config -> initialized := true)

let finalize _ctx = ()

Expand Down
1 change: 1 addition & 0 deletions arrayjit/lib/cuda_backend.missing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type stream = Unimplemented_stream [@@deriving sexp_of]
type device = Unimplemented_device [@@deriving sexp_of]

let init Unimplemented_stream = Unimplemented_ctx
let buffer_ptr _ctx_array = Unimplemented_buffer_ptr
let alloc_buffer ?old_buffer:_ ~size_in_bytes:_ Unimplemented_stream = Unimplemented_buffer_ptr
let await _stream = ()
let is_idle _stream = true
Expand Down
3 changes: 1 addition & 2 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ type mem_properties =
let root_ctx = ref None

module Tn = Tnode
include Backend_types.No_device_types

type buffer_ptr = ctx_array [@@deriving sexp_of]
(** Alternative approach:
Expand Down Expand Up @@ -57,7 +56,7 @@ let buffer_to_host dst ~src = Ndarray.map2 { f2 = Ndarray.A.blit } src dst

let is_initialized () = Option.is_some !root_ctx

let initialize () =
let initialize _config =
if Option.is_none !root_ctx then (
let open Gccjit in
let ctx = Context.create () in
Expand Down
20 changes: 20 additions & 0 deletions lib/attic.mld
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,24 @@ let input_or_recurrent_nodes asgns =
in
loop asgns

]}

Upcoming in backend_types.ml:
{[

val from_host : dst_ptr:buffer_ptr -> dst:context -> Tnode.t -> unit
(** Like {!Backend.from_host}, but without synchronization and buffer retrieval. *)

val to_host : src_ptr:buffer_ptr -> src:context -> Tnode.t -> unit
(** Like {!Backend.to_host}, but without synchronization and buffer retrieval. *)

val device_to_device :
Tnode.t ->
into_merge_buffer:merge_buffer_use ->
dst_ptr:buffer_ptr ->
dst:context ->
src_ptr:buffer_ptr ->
src:context ->
unit
(** Like {!Backend.device_to_device}, but without synchronization and buffer retrieval. *)
]}

0 comments on commit 1416586

Please sign in to comment.