Skip to content

Commit

Permalink
Add closures support for compilation speed (#37)
Browse files Browse the repository at this point in the history
* Try adding closure functionality in code gen

* Fix type inference ability and make closures into anonymous functions

* Improve required allocations dramatically

* Add error for failure in type inference

* Turn closures off by default
  • Loading branch information
AntonReinhard authored Nov 1, 2024
1 parent bb191d6 commit 9a93eb8
Show file tree
Hide file tree
Showing 16 changed files with 296 additions and 76 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ CUDAExt = "CUDA"
oneAPIExt = "oneAPI"

[compat]
julia = "1.10"
AMDGPU = "1"
CUDA = "5"
DataStructures = "0.18"
NumaAllocators = "0.2"
oneAPI = "1"
RuntimeGeneratedFunctions = "0.5"
StaticArrays = "1"
julia = "1.10"
oneAPI = "1"

[extras]
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
9 changes: 6 additions & 3 deletions ext/devices/cuda/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ function ComputableDAGs.kernel(

init_caches = Expr(:block, tape.initCachesCode...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)
# TODO: use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs.gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine),
ComputableDAGs.entry_device(tape.machine).cacheStrategy,
tape.outputSymbol,
),
)
expr = Meta.parse(
Expand Down
12 changes: 8 additions & 4 deletions ext/devices/rocm/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ function ComputableDAGs.kernel(

init_caches = Expr(:block, tape.initCachesCode...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)

# TODO use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs.gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine),
ComputableDAGs.entry_device(tape.machine).cacheStrategy,
tape.outputSymbol,
),
)
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (workgroupIdx().x - 1) * workgroupDim().x + workgroupIdx().x
if (id > n)
if (id > n)
return
end
@inline data_input = input_vector[id]
Expand Down
1 change: 1 addition & 0 deletions src/ComputableDAGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ include("scheduler/interface.jl")
include("scheduler/greedy.jl")

include("code_gen/type.jl")
include("code_gen/utils.jl")
include("code_gen/tape_machine.jl")
include("code_gen/function.jl")

Expand Down
16 changes: 13 additions & 3 deletions src/code_gen/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,28 @@ using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)
```
in your top level.
## Keyword Arguments
`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
"""
function get_compute_function(
graph::DAG, instance, machine::Machine, context_module::Module
graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0
)
tape = gen_tape(graph, instance, machine, context_module)

initCaches = Expr(:block, tape.initCachesCode...)
assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
code = Expr(:block, expr_from_fc.(tape.computeCode)...)
code = gen_function_body(tape; closures_size=closures_size)

functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
resSym = eval(
_gen_access_expr(
entry_device(tape.machine),
entry_device(tape.machine).cacheStrategy,
tape.outputSymbol,
),
)
expr = #
Expr(
:function, # function definition
Expand Down
135 changes: 117 additions & 18 deletions src/code_gen/tape_machine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ end

function expr_from_fc(fc::FunctionCall{VectorT,0}) where {VectorT}
func_call = Expr(
:call, fc.func, eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...
:call,
fc.func,
eval.(
_gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
)...,
)
access_expr = eval(gen_access_expr(fc.device, fc.return_symbol))
access_expr = eval(gen_access_expr(fc))

return Expr(:(=), access_expr, func_call)
end
Expand All @@ -69,9 +73,11 @@ function expr_from_fc(fc::FunctionCall{VectorT,M}) where {VectorT,M}
:call,
fc.func,
fc.value_arguments...,
eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
eval.(
_gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
)...,
)
access_expr = eval(gen_access_expr(fc.device, fc.return_symbol))
access_expr = eval(gen_access_expr(fc))

return Expr(:(=), access_expr, func_call)
end
Expand Down Expand Up @@ -115,14 +121,11 @@ function gen_input_assignment_code(
device = entry_device(machine)

fc = FunctionCall(
RuntimeGeneratedFunction(
@__MODULE__,
context_module,
Expr(:->, :x, input_expr(instance, name, :x)),
),
context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))),
SVector{0,Any}(),
SVector{1,Symbol}(:input),
symbol,
Nothing,
device,
)

Expand All @@ -133,6 +136,89 @@ function gen_input_assignment_code(
return assign_inputs
end

"""
    gen_function_body(tape::Tape; closures_size)

Generate the function body from the given [`Tape`](@ref).

## Keyword Arguments
`closures_size`: The size of closures to generate (in lines of code). Closures introduce
function barriers in the function body, preventing some optimizations by the compiler and
therefore greatly reducing compile time. A value of 1 or less will disable the use of
closures entirely.
"""
function gen_function_body(tape::Tape; closures_size::Int)
    if closures_size > 1
        # return types only need to be annotated when closures are used, to keep the
        # generated code type-stable across the function barriers
        infer_types!(tape)
    end

    fc_vec = tape.schedule

    if closures_size <= 1
        # closures disabled: emit the whole schedule as one flat block
        return Expr(:block, expr_from_fc.(fc_vec)...)
    end

    closures = Vector{Expr}()
    # iterate from end to beginning
    # this helps because we can collect all undefined arguments to the closures that have
    # to be returned somewhere earlier
    undefined_argument_symbols = Set{Symbol}()
    # the final return symbol is the return of the entire generated function, it always
    # has to be returned
    push!(undefined_argument_symbols, eval(gen_access_expr(fc_vec[end])))

    for i in length(fc_vec):(-closures_size):1
        e = i
        # BUGFIX: this was `b = max(i - closures_size, 1)`, which made every block span
        # `closures_size + 1` calls; with the loop stepping by `-closures_size`,
        # consecutive blocks overlapped by one call, which was then generated (and
        # computed at runtime) twice. The `+ 1` makes the blocks disjoint.
        b = max(i - closures_size + 1, 1)
        code_block = fc_vec[b:e]

        # collect `local var` statements that need to exist before the closure starts
        local_inits = gen_local_init.(code_block)

        return_symbols = eval.(gen_access_expr.(code_block))

        ret_symbols_set = Set(return_symbols)
        for fc in code_block
            for arg in fc.arguments
                symbol = eval(_gen_access_expr(fc.device, fc.device.cacheStrategy, arg))

                # symbol won't be defined if it is first calculated in the closure
                # so don't add it to the arguments in this case
                if !(symbol in ret_symbols_set)
                    push!(undefined_argument_symbols, symbol)
                end
            end
        end

        # only return symbols that a later closure (or the final function return) needs
        intersect!(ret_symbols_set, undefined_argument_symbols)
        return_symbols = Symbol[ret_symbols_set...]

        closure = Expr(
            :block,
            Expr(
                :(=),
                Expr(:tuple, return_symbols...),
                Expr(
                    :call, # call to the following closure (no arguments)
                    Expr( # create the closure: () -> code block; return (locals)
                        :->,
                        :(), # closure arguments (none)
                        Expr( # actual function body of the closure
                            :block,
                            local_inits..., # declare local variables with type information inside the closure
                            expr_from_fc.(code_block)...,
                            Expr(:return, Expr(:tuple, return_symbols...)),
                        ),
                    ),
                ),
            ),
        )

        # everything this closure returns is now defined for earlier blocks
        setdiff!(undefined_argument_symbols, ret_symbols_set)

        # combine to one closure call, including all the local inits and the actual call
        # to the closure
        pushfirst!(closures, closure)
    end

    return Expr(:block, closures...)
end

"""
gen_tape(
graph::DAG,
Expand All @@ -154,25 +240,33 @@ function gen_tape(
scheduler::AbstractScheduler=GreedyScheduler(),
)
schedule = schedule_dag(scheduler, graph, machine)
function_body = lower(schedule, machine)

# get inSymbols
inputSyms = Dict{String,Vector{Symbol}}()
# get input symbols
input_syms = Dict{String,Vector{Symbol}}()
for node in get_entry_nodes(graph)
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
if !haskey(input_syms, node.name)
input_syms[node.name] = Vector{Symbol}()
end

push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
push!(input_syms[node.name], Symbol("$(to_var_name(node.id))_in"))
end

# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))

initCaches = gen_cache_init_code(machine)
assign_inputs = gen_input_assignment_code(inputSyms, instance, machine, context_module)
init_caches = gen_cache_init_code(machine)
assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module)

return Tape{input_type(instance)}(
initCaches, assign_inputs, schedule, inputSyms, outSym, Dict(), instance, machine
init_caches,
assign_inputs,
function_body,
input_syms,
outSym,
Dict(),
instance,
machine,
)
end

Expand All @@ -182,6 +276,9 @@ end
Execute the given tape with the given input.
For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
!!! warning
This is very slow and might not work. This is to be majorly revamped.
"""
function execute_tape(tape::Tape, input)
cache = Dict{Symbol,Any}()
Expand All @@ -192,10 +289,12 @@ function execute_tape(tape::Tape, input)
@eval $expr
end

compute_code = tape.schedule

for function_call in tape.inputAssignCode
call_fc(function_call, cache)
end
for function_call in tape.computeCode
for function_call in compute_code
call_fc(function_call, cache)
end

Expand Down
2 changes: 1 addition & 1 deletion src/code_gen/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TODO: update docs
struct Tape{INPUT}
initCachesCode::Vector{Expr}
inputAssignCode::Vector{FunctionCall}
computeCode::Vector{FunctionCall}
schedule::Vector{FunctionCall}
inputSymbols::Dict{String,Vector{Symbol}}
outputSymbol::Symbol
cache::Dict{Symbol,Any}
Expand Down
45 changes: 45 additions & 0 deletions src/code_gen/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
    infer_types!(tape::Tape)

Infer the result type of each [`FunctionCall`](@ref) in the given tape, storing it in the
call's `return_type` field and returning `nothing`. This assumes that each function call
has exactly one statically inferrable return type and will throw an exception otherwise.

This also assumes that the tape's schedule contains a topological ordering of its function
calls, such as returned by a call to [`schedule_dag`](@ref).
"""
function infer_types!(tape::Tape)
    known_result_types = Dict{Symbol,Type}()

    # the type of the :input symbol is the only one known up front
    known_result_types[:input] = input_type(tape.instance)

    # input assignments run before the scheduled compute calls, so their result types
    # must be recorded first; both loops are otherwise identical
    for fc in Iterators.flatten((tape.inputAssignCode, tape.schedule))
        res_type = result_type(fc, known_result_types)
        fc.return_type = res_type
        known_result_types[fc.return_symbol] = res_type
    end

    return nothing
end

"""
    lower(schedule::Vector{Node}, machine::Machine)

After [`schedule_dag`](@ref) has made a schedule of nodes, this function lowers the vector
of [`Node`](@ref)s into a vector of [`FunctionCall`](@ref)s.
"""
function lower(schedule::Vector{Node}, machine::Machine)
    lowered = Vector{FunctionCall}()

    for n in schedule
        if n isa DataTaskNode && isempty(children(n))
            # entry data nodes have no compute task; they get an init call instead
            push!(lowered, get_init_function_call(n, entry_device(machine)))
        else
            append!(lowered, get_function_call(n))
        end
    end

    return lowered
end
18 changes: 18 additions & 0 deletions src/devices/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,21 @@ function cpu_st()
[NumaNode(0, 1, default_strategy(NumaNode), -1.0, UUIDs.uuid1())], [-1.0;;]
)
end

"""
    gen_access_expr(fc::FunctionCall)

Dispatch from the given [`FunctionCall`](@ref) to the interface function
[`_gen_access_expr`](@ref), using the call's device and that device's cache strategy.
"""
function gen_access_expr(fc::FunctionCall)
    dev = fc.device
    return _gen_access_expr(dev, dev.cacheStrategy, fc.return_symbol)
end

"""
    gen_local_init(fc::FunctionCall)

Dispatch from the given [`FunctionCall`](@ref) to the interface function
[`_gen_local_init`](@ref), using the call's device and that device's cache strategy.
"""
function gen_local_init(fc::FunctionCall)
    dev = fc.device
    return _gen_local_init(fc, dev, dev.cacheStrategy)
end
Loading

0 comments on commit 9a93eb8

Please sign in to comment.