Skip to content

Commit

Permalink
Backends: Factor out buffer retrieval from copying, finalize design
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 21, 2024
1 parent 1416586 commit 50afe3e
Show file tree
Hide file tree
Showing 9 changed files with 409 additions and 458 deletions.
116 changes: 70 additions & 46 deletions arrayjit/lib/backend_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,36 @@ module type Backend_common = sig
[occupancy] returns true are included. *)
end

(** Parts shared by backend implementations excluding what's already in {!Backend_any_common}.

This signature declares only the abstract context / array / buffer types, the accessors that
relate them, and the {!is_in_context} predicate. It declares no copying, compiling or linking
operations. Each type carries [[@@deriving sexp_of]], so an implementation must also supply
the corresponding [sexp_of_*] converter. *)
module type Backend_impl_common = sig
type context [@@deriving sexp_of]
(** The backend's context. A context gives access to a collection of arrays via
{!val-ctx_arrays}. *)

type ctx_array [@@deriving sexp_of]
(** The entry stored for one tensor node inside a {!type-ctx_arrays} collection, as returned by
{!get_array}. *)

type ctx_arrays [@@deriving sexp_of]
(** The collection of arrays of a context, obtained with {!val-ctx_arrays} and queried per tensor
node with {!get_array}. *)

type buffer_ptr [@@deriving sexp_of]
(** The backend's buffer handle, obtained from a {!type-ctx_array} with {!val-buffer_ptr}. *)

val buffer_ptr : ctx_array -> buffer_ptr
(** [buffer_ptr arr] is the buffer handle associated with the context array [arr]. *)

val ctx_arrays : context -> ctx_arrays
(** [ctx_arrays ctx] is the collection of arrays associated with the context [ctx]. *)

val get_array : ctx_arrays -> Tnode.t -> ctx_array option
(** [get_array arrays tn] looks up the entry for the tensor node [tn] in [arrays].
NOTE(review): [None] presumably means [tn] has no array in [arrays] -- confirm against the
implementations. *)

val is_in_context : Low_level.traced_array -> bool
(** If true, the node is required to be in the contexts linked with code that uses it.
Should return false for nodes that are virtual, local, or which the backend prefers to access
directly from the host. *)
end

(** Copying operations expressed directly on buffer handles and host arrays. No context, stream
or device argument appears in any of the signatures. This signature is included by
{!No_device_backend} and {!Lowered_no_device_backend} with [buffer_ptr] substituted by the
including signature's own type. *)
module type No_device_copying = sig
type buffer_ptr
(** The buffer handle the copying functions operate on. It is left abstract here so that
including signatures can substitute their own [buffer_ptr]. *)

val buffer_to_buffer : dst:buffer_ptr -> src:buffer_ptr -> unit
(** [buffer_to_buffer ~dst ~src] copies between two buffers; going by the labels, from [src] into
[dst]. NOTE(review): the signature carries no size argument, so how the extent of the copy is
determined is not visible here -- confirm against the implementations. *)

val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
(** [host_to_buffer nd ~dst] copies, going by the name and labels, the contents of the host array
[nd] into the buffer [dst]. *)

val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
(** [buffer_to_host nd ~src] copies, going by the name and labels, the contents of the buffer
[src] into the host array [nd]. *)
end

(** An intermediate interface for stream-agnostic (typically CPU) backend implementations. *)
module type No_device_backend = sig
include Backend_common with type init_info := string and type stream := unit
include Backend_impl_common with type context := context and type buffer_ptr := buffer_ptr

val link : merge_buffer:(buffer_ptr * Tnode.t) option ref -> context -> code -> routine
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
Expand All @@ -114,10 +141,7 @@ module type No_device_backend = sig
val get_used_memory : unit -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
val get_buffer : Tnode.t -> context -> buffer_ptr option
include No_device_copying with type buffer_ptr := buffer_ptr
end

(** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
Expand Down Expand Up @@ -181,21 +205,8 @@ module type Backend_device_common = sig
val get_name : stream -> string
end

module type Backend = sig
include Backend_device_common

include
Backend_common
with type context := context
and type init_info := stream
and type stream := stream

val link : context -> code -> routine
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)

val link_batch : context -> code_batch -> context * routine option array
(** Returns the routines for the procedures included in the code batch. The returned context is
downstream of all the returned routines. *)
module type With_buffer_retrieval_and_syncing = sig
type context

val from_host : context -> Tnode.t -> bool
(** If the tensor node is both hosted and in-context, schedules a copy from host to context and
Expand Down Expand Up @@ -227,27 +238,28 @@ module type Backend = sig
[will_wait_for src (all_work (get_ctx_stream dst))]. *)
end

(** Parts shared by lowered-level backends excluding what's already in {!Backend_any_common}. *)
module type Lowered_backend_common = sig
type context [@@deriving sexp_of]
type ctx_array [@@deriving sexp_of]
type ctx_arrays [@@deriving sexp_of]
type buffer_ptr
module type Backend = sig
include Backend_device_common

val buffer_ptr : ctx_array -> buffer_ptr
val ctx_arrays : context -> ctx_arrays
val get_array : ctx_arrays -> Tnode.t -> ctx_array option
include
Backend_common
with type context := context
and type init_info := stream
and type stream := stream

val is_in_context : Low_level.traced_array -> bool
(** If true, the node is required to be in the contexts linked with code that uses it.
val link : context -> code -> routine
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)

Should return false for nodes that are virtual, local, or which the backend prefers to access
directly from the host. *)
val link_batch : context -> code_batch -> context * routine option array
(** Returns the routines for the procedures included in the code batch. The returned context is
downstream of all the returned routines. *)

include With_buffer_retrieval_and_syncing with type context := context
end

(** Lowered-level stream agnostic backend interface: implementation-facing API for CPU backends. *)
module type Lowered_no_device_backend = sig
include Lowered_backend_common
include Backend_impl_common

include
Backend_any_common
Expand Down Expand Up @@ -278,17 +290,36 @@ module type Lowered_no_device_backend = sig
procedure ->
context * Indexing.lowered_bindings * Task.t * string

val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
include No_device_copying with type buffer_ptr := buffer_ptr
end

(** Copying primitives of a device-based backend implementation in which the caller supplies the
buffer handles explicitly. Per the documentation of its members, these are counterparts of the
same-named {!Backend} functions without synchronization and buffer retrieval. Unlike
[Backend.from_host] (which takes a [Tnode.t] and returns [bool]), the host transfers here take
an [Ndarray.t] and return [unit]. This signature is included by {!Lowered_backend}. *)
module type No_buffer_retrieval_or_syncing = sig
include Backend_impl_common
include Backend_device_common with type context := context and type buffer_ptr := buffer_ptr

val from_host : dst_ptr:buffer_ptr -> dst:context -> Ndarray.t -> unit
(** Like {!Backend.from_host}, but without synchronization and buffer retrieval. The destination
buffer is passed explicitly as [dst_ptr] alongside the destination context [dst]. *)

val to_host : src_ptr:buffer_ptr -> src:context -> Ndarray.t -> unit
(** Like {!Backend.to_host}, but without synchronization and buffer retrieval. The source buffer
is passed explicitly as [src_ptr] alongside the source context [src]. *)

val device_to_device :
Tnode.t ->
into_merge_buffer:Types.merge_buffer_use ->
dst_ptr:buffer_ptr option ->
dst:context ->
src_ptr:buffer_ptr ->
src:context ->
unit
(** Like {!Backend.device_to_device}, but without synchronization and buffer retrieval. Raises
[Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None].

[dst_ptr] is optional while [src_ptr] is mandatory. NOTE(review): presumably [dst_ptr] may be
[None] only when the copy targets the merge buffer -- confirm against the implementations. *)
end

(** Lowered-level backend interface: implementation-facing API for device-based (typically GPU)
backends. *)
module type Lowered_backend = sig
include Lowered_backend_common
include Backend_device_common with type context := context and type buffer_ptr := buffer_ptr

include No_buffer_retrieval_or_syncing

type code [@@deriving sexp_of]
type code_batch [@@deriving sexp_of]

Expand All @@ -305,13 +336,6 @@ module type Lowered_backend = sig
val link_batch :
context -> code_batch -> context * Indexing.lowered_bindings * Task.t option array

val from_host : context -> Tnode.t -> bool

val to_host : context -> Tnode.t -> bool

val device_to_device :
Tnode.t -> into_merge_buffer:Types.merge_buffer_use -> dst:context -> src:context -> bool

val scheduled_merge_node : stream -> Tnode.t option
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
right after [await stream]. *)
Expand Down
Loading

0 comments on commit 50afe3e

Please sign in to comment.