Add closures support for compilation speed #37

Merged · 5 commits · Nov 1, 2024
Changes from all commits
4 changes: 2 additions & 2 deletions Project.toml
@@ -22,14 +22,14 @@ CUDAExt = "CUDA"
 oneAPIExt = "oneAPI"
 
 [compat]
-julia = "1.10"
 AMDGPU = "1"
 CUDA = "5"
 DataStructures = "0.18"
 NumaAllocators = "0.2"
-oneAPI = "1"
 RuntimeGeneratedFunctions = "0.5"
 StaticArrays = "1"
+julia = "1.10"
+oneAPI = "1"
 
 [extras]
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
9 changes: 6 additions & 3 deletions ext/devices/cuda/function.jl
@@ -6,12 +6,15 @@ function ComputableDAGs.kernel(
 
     init_caches = Expr(:block, tape.initCachesCode...)
     assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
-    code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)
+    # TODO: use gen_function_body here
+    code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)
 
     function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
     res_sym = eval(
-        ComputableDAGs.gen_access_expr(
-            ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
+        ComputableDAGs._gen_access_expr(
+            ComputableDAGs.entry_device(tape.machine),
+            ComputableDAGs.entry_device(tape.machine).cacheStrategy,
+            tape.outputSymbol,
         ),
     )
     expr = Meta.parse(
12 changes: 8 additions & 4 deletions ext/devices/rocm/function.jl
@@ -6,18 +6,22 @@ function ComputableDAGs.kernel(
 
     init_caches = Expr(:block, tape.initCachesCode...)
     assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
-    code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...)
+
+    # TODO: use gen_function_body here
+    code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)
 
     function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
     res_sym = eval(
-        ComputableDAGs.gen_access_expr(
-            ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
+        ComputableDAGs._gen_access_expr(
+            ComputableDAGs.entry_device(tape.machine),
+            ComputableDAGs.entry_device(tape.machine).cacheStrategy,
+            tape.outputSymbol,
         ),
     )
     expr = Meta.parse(
         "function compute_$(function_id)(input_vector, output_vector, n::Int64)
            id = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
-           if (id > n)
+           if (id > n)
                return
            end
            @inline data_input = input_vector[id]
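For orientation, the kernel source that this Meta.parse call assembles looks roughly like the sketch below once the generated pieces are spliced in. The function id, the compute body, and the final output write are illustrative assumptions, since the diff is truncated before the kernel's end:

```julia
# Hypothetical rendering of a generated ROCm kernel (names invented):
function compute_4ae9b3(input_vector, output_vector, n::Int64)
    # one GPU work-item per input element, 1-based AMDGPU indexing
    id = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    if (id > n)
        return
    end
    @inline data_input = input_vector[id]
    # ... init_caches, assign_inputs, and the scheduled compute code go here ...
    output_vector[id] = result  # assumed: write the tape's output symbol back
    return nothing
end
```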
1 change: 1 addition & 0 deletions src/ComputableDAGs.jl
@@ -130,6 +130,7 @@ include("scheduler/interface.jl")
 include("scheduler/greedy.jl")
 
 include("code_gen/type.jl")
+include("code_gen/utils.jl")
 include("code_gen/tape_machine.jl")
 include("code_gen/function.jl")
16 changes: 13 additions & 3 deletions src/code_gen/function.jl
@@ -14,18 +14,28 @@ using RuntimeGeneratedFunctions
 RuntimeGeneratedFunctions.init(@__MODULE__)
 ```
 in your top level.
+
+## Keyword Arguments
+
+`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
 """
 function get_compute_function(
-    graph::DAG, instance, machine::Machine, context_module::Module
+    graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0
 )
     tape = gen_tape(graph, instance, machine, context_module)
 
     initCaches = Expr(:block, tape.initCachesCode...)
     assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
-    code = Expr(:block, expr_from_fc.(tape.computeCode)...)
+    code = gen_function_body(tape; closures_size=closures_size)
 
     functionId = to_var_name(UUIDs.uuid1(rng[1]))
-    resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
+    resSym = eval(
+        _gen_access_expr(
+            entry_device(tape.machine),
+            entry_device(tape.machine).cacheStrategy,
+            tape.outputSymbol,
+        ),
+    )
     expr = #
         Expr(
             :function, # function definition
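A quick usage sketch of the new keyword. The DAG setup is elided: `graph::DAG` and a matching problem `instance` are assumed to exist, and `cpu_st()` is the single-threaded CPU machine from src/devices/impl.jl:

```julia
using ComputableDAGs
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

machine = cpu_st()  # single-threaded CPU machine

# closures_size=50 wraps the generated body into closures of roughly 50
# statements each, trading some runtime optimization for much faster
# compilation of large DAGs
f = get_compute_function(graph, instance, machine, @__MODULE__; closures_size=50)
# result = f(input)  # the returned function takes one input of input_type(instance)
```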
135 changes: 117 additions & 18 deletions src/code_gen/tape_machine.jl
@@ -52,9 +52,13 @@ end
 
 function expr_from_fc(fc::FunctionCall{VectorT,0}) where {VectorT}
     func_call = Expr(
-        :call, fc.func, eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...
+        :call,
+        fc.func,
+        eval.(
+            _gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
+        )...,
     )
-    access_expr = eval(gen_access_expr(fc.device, fc.return_symbol))
+    access_expr = eval(gen_access_expr(fc))
 
     return Expr(:(=), access_expr, func_call)
 end
@@ -69,9 +73,11 @@ function expr_from_fc(fc::FunctionCall{VectorT,M}) where {VectorT,M}
         :call,
         fc.func,
         fc.value_arguments...,
-        eval.(gen_access_expr.(Ref(fc.device), fc.arguments))...,
+        eval.(
+            _gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments)
+        )...,
     )
-    access_expr = eval(gen_access_expr(fc.device, fc.return_symbol))
+    access_expr = eval(gen_access_expr(fc))
 
     return Expr(:(=), access_expr, func_call)
 end
@@ -115,14 +121,11 @@ function gen_input_assignment_code(
             device = entry_device(machine)
 
             fc = FunctionCall(
-                RuntimeGeneratedFunction(
-                    @__MODULE__,
-                    context_module,
-                    Expr(:->, :x, input_expr(instance, name, :x)),
-                ),
+                context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))),
                 SVector{0,Any}(),
                 SVector{1,Symbol}(:input),
                 symbol,
+                Nothing,
                 device,
             )
 
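The switch from RuntimeGeneratedFunction to `context_module.eval` builds the input-assignment lambda directly in the caller's module. A minimal standalone sketch of that pattern, not the package's API:

```julia
# Evaluate a lambda expression into a target module, then call it.
# Base.invokelatest avoids world-age errors when calling a freshly eval'd function.
m = Module(:Sandbox)                       # stand-in for context_module
f = Core.eval(m, Expr(:->, :x, :(2 * x)))  # same shape as Expr(:->, :x, input_expr(...))
@assert Base.invokelatest(f, 21) == 42
```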
@@ -133,6 +136,89 @@
     return assign_inputs
 end
 
+"""
+    gen_function_body(tape::Tape; closures_size)
+
+Generate the function body from the given [`Tape`](@ref).
+
+## Keyword Arguments
+`closures_size`: The size of the closures to generate (in lines of code). Closures introduce function barriers into the function body, preventing some compiler optimizations across them and thereby greatly reducing compile time. A value of 1 or less disables the use of closures entirely.
+"""
+function gen_function_body(tape::Tape; closures_size::Int)
+    if closures_size > 1
+        # types only need to be annotated when closures are used
+        infer_types!(tape)
+    end
+
+    fc_vec = tape.schedule
+
+    if (closures_size <= 1)
+        return Expr(:block, expr_from_fc.(fc_vec)...)
+    end
+
+    closures = Vector{Expr}()
+    # iterate from the end to the beginning, collecting along the way all arguments
+    # that are undefined at a closure's start and therefore have to be returned by an earlier closure
+    undefined_argument_symbols = Set{Symbol}()
+    # the final return symbol is the return value of the entire generated function, so it always has to be returned
+    push!(undefined_argument_symbols, eval(gen_access_expr(fc_vec[end])))
+
+    for i in length(fc_vec):(-closures_size):1
+        e = i
+        b = max(i - closures_size + 1, 1)
+        code_block = fc_vec[b:e]
+
+        # collect `local var` statements that need to exist before the closure starts
+        local_inits = gen_local_init.(code_block)
+
+        return_symbols = eval.(gen_access_expr.(code_block))
+
+        ret_symbols_set = Set(return_symbols)
+        for fc in code_block
+            for arg in fc.arguments
+                symbol = eval(_gen_access_expr(fc.device, fc.device.cacheStrategy, arg))
+
+                # the symbol won't be defined yet if it is first calculated inside the closure,
+                # so don't add it to the closure arguments in that case
+                if !(symbol in ret_symbols_set)
+                    push!(undefined_argument_symbols, symbol)
+                end
+            end
+        end
+
+        intersect!(ret_symbols_set, undefined_argument_symbols)
+        return_symbols = Symbol[ret_symbols_set...]
+
+        closure = Expr(
+            :block,
+            Expr(
+                :(=),
+                Expr(:tuple, return_symbols...),
+                Expr(
+                    :call, # call to the following closure (no arguments)
+                    Expr( # create the closure: () -> code block; return (locals)
+                        :->,
+                        :(), # closure arguments (none)
+                        Expr( # actual function body of the closure
+                            :block,
+                            local_inits..., # declare local variables with type information inside the closure
+                            expr_from_fc.(code_block)...,
+                            Expr(:return, Expr(:tuple, return_symbols...)),
+                        ),
+                    ),
+                ),
+            ),
+        )
+
+        setdiff!(undefined_argument_symbols, ret_symbols_set)
+
+        # prepend so the closures end up in execution order, each including its local inits and the actual call to the closure
+        pushfirst!(closures, closure)
+    end
+
+    return Expr(:block, closures...)
+end
+
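To make the Expr construction above concrete, one block of `closures_size = 2` roughly lowers to the pattern sketched below. The symbol names and stand-in functions are invented; the real generated code uses UUID-based names:

```julia
f1(x, y) = x + y   # stand-ins for the tape's compute functions
f2(x, y) = x * y
r1, r2 = 1.0, 2.0  # results produced by an earlier closure

# Each closure is a function barrier: the compiler optimizes within it,
# but not across it, which keeps compile time bounded for huge bodies.
(r3, r4) = (() -> begin
    local r3::Float64  # `local` declarations from gen_local_init,
    local r4::Float64  # typed via infer_types!
    r3 = f1(r1, r2)    # expr_from_fc output for each FunctionCall
    r4 = f2(r3, r2)
    return (r3, r4)    # only symbols needed by later closures are returned
end)()
```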
"""
gen_tape(
graph::DAG,
Expand All @@ -154,25 +240,33 @@ function gen_tape(
scheduler::AbstractScheduler=GreedyScheduler(),
)
schedule = schedule_dag(scheduler, graph, machine)
function_body = lower(schedule, machine)

# get inSymbols
inputSyms = Dict{String,Vector{Symbol}}()
# get input symbols
input_syms = Dict{String,Vector{Symbol}}()
for node in get_entry_nodes(graph)
if !haskey(inputSyms, node.name)
inputSyms[node.name] = Vector{Symbol}()
if !haskey(input_syms, node.name)
input_syms[node.name] = Vector{Symbol}()
end

push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
push!(input_syms[node.name], Symbol("$(to_var_name(node.id))_in"))
end

# get outSymbol
outSym = Symbol(to_var_name(get_exit_node(graph).id))

initCaches = gen_cache_init_code(machine)
assign_inputs = gen_input_assignment_code(inputSyms, instance, machine, context_module)
init_caches = gen_cache_init_code(machine)
assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module)

return Tape{input_type(instance)}(
initCaches, assign_inputs, schedule, inputSyms, outSym, Dict(), instance, machine
init_caches,
assign_inputs,
function_body,
input_syms,
outSym,
Dict(),
instance,
machine,
)
end

Expand All @@ -182,6 +276,9 @@ end
Execute the given tape with the given input.

For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.

!!! warning
This is very slow and might not work. This is to be majorly revamped.
"""
function execute_tape(tape::Tape, input)
cache = Dict{Symbol,Any}()
Expand All @@ -192,10 +289,12 @@ function execute_tape(tape::Tape, input)
@eval $expr
end

compute_code = tape.schedule

for function_call in tape.inputAssignCode
call_fc(function_call, cache)
end
for function_call in tape.computeCode
for function_call in compute_code
call_fc(function_call, cache)
end

Expand Down
2 changes: 1 addition & 1 deletion src/code_gen/type.jl
@@ -12,7 +12,7 @@ TODO: update docs
 struct Tape{INPUT}
     initCachesCode::Vector{Expr}
     inputAssignCode::Vector{FunctionCall}
-    computeCode::Vector{FunctionCall}
+    schedule::Vector{FunctionCall}
     inputSymbols::Dict{String,Vector{Symbol}}
     outputSymbol::Symbol
     cache::Dict{Symbol,Any}
45 changes: 45 additions & 0 deletions src/code_gen/utils.jl
@@ -0,0 +1,45 @@
+"""
+    infer_types!(tape::Tape)
+
+Infer the result type of each function call in the given tape, annotating it on the [`FunctionCall`](@ref). This assumes that each node has only one statically inferrable return type and will throw an exception otherwise.
+This also assumes that the tape's schedule is in topological order, such as returned by a call to [`schedule_dag`](@ref).
+"""
+function infer_types!(tape::Tape)
+    known_result_types = Dict{Symbol,Type}()
+
+    # the only initially known type
+    known_result_types[:input] = input_type(tape.instance)
+
+    for fc in tape.inputAssignCode
+        res_type = result_type(fc, known_result_types)
+        fc.return_type = res_type
+        known_result_types[fc.return_symbol] = res_type
+    end
+
+    for fc in tape.schedule
+        res_type = result_type(fc, known_result_types)
+        fc.return_type = res_type
+        known_result_types[fc.return_symbol] = res_type
+    end
+
+    return nothing
+end
+
+"""
+    lower(schedule::Vector{Node}, machine::Machine)
+
+After [`schedule_dag`](@ref) has produced a schedule of nodes, this function lowers the vector of [`Node`](@ref)s into a vector of [`FunctionCall`](@ref)s.
+"""
+function lower(schedule::Vector{Node}, machine::Machine)
+    calls = Vector{FunctionCall}()
+
+    for node in schedule
+        if (node isa DataTaskNode && length(children(node)) == 0)
+            push!(calls, get_init_function_call(node, entry_device(machine)))
+        else
+            push!(calls, get_function_call(node)...)
+        end
+    end
+
+    return calls
+end
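The forward propagation that infer_types! performs can be pictured with the standalone toy below. The tuples and `Base.promote_op` stand in for `FunctionCall` and `result_type`, which this diff does not show:

```julia
# Walk a topologically ordered schedule and propagate result types forward.
known = Dict{Symbol,Type}(:input => Float64)  # the only initially known type
toy_schedule = [(:a, sin, [:input]), (:b, +, [:a, :input])]
for (sym, f, args) in toy_schedule
    argtypes = getindex.(Ref(known), args)    # look up the argument types
    known[sym] = Base.promote_op(f, argtypes...)
end
@assert known[:b] == Float64
```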
18 changes: 18 additions & 0 deletions src/devices/impl.jl
@@ -62,3 +62,21 @@ function cpu_st()
         [NumaNode(0, 1, default_strategy(NumaNode), -1.0, UUIDs.uuid1())], [-1.0;;]
     )
 end
+
+"""
+    gen_access_expr(fc::FunctionCall)
+
+Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_access_expr`](@ref).
+"""
+function gen_access_expr(fc::FunctionCall)
+    return _gen_access_expr(fc.device, fc.device.cacheStrategy, fc.return_symbol)
+end
+
+"""
+    gen_local_init(fc::FunctionCall)
+
+Dispatch from the given [`FunctionCall`](@ref) to the interface function [`_gen_local_init`](@ref).
+"""
+function gen_local_init(fc::FunctionCall)
+    return _gen_local_init(fc, fc.device, fc.device.cacheStrategy)
+end