From 41a1bcd90b4b03100f7747a48d35dabaedb8682d Mon Sep 17 00:00:00 2001 From: Rubydragon Date: Thu, 3 Oct 2024 11:14:47 +0200 Subject: [PATCH 1/5] Try adding closure functionality in code gen --- ext/devices/cuda/function.jl | 7 +++-- ext/devices/rocm/function.jl | 10 +++++-- src/code_gen/function.jl | 16 +++++++++-- src/code_gen/tape_machine.jl | 54 +++++++++++++++++++++++++++++++++--- src/devices/impl.jl | 18 ++++++++++++ src/devices/interface.jl | 13 +++++++-- src/devices/numa/impl.jl | 45 +++++++++++++++++++----------- 7 files changed, 133 insertions(+), 30 deletions(-) diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index 1bf5a5d..e7c392e 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -6,12 +6,15 @@ function ComputableDAGs.kernel( init_caches = Expr(:block, tape.initCachesCode...) assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) + # TODO: use gen_function_body here code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...) function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( - ComputableDAGs.gen_access_expr( - ComputableDAGs.entry_device(tape.machine), tape.outputSymbol + ComputableDAGs._gen_access_expr( + ComputableDAGs.entry_device(tape.machine), + ComputableDAGs.entry_device(tape.machine).cacheStrategy, + tape.outputSymbol, ), ) expr = Meta.parse( diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl index 2ecca3b..ae8572b 100644 --- a/ext/devices/rocm/function.jl +++ b/ext/devices/rocm/function.jl @@ -6,18 +6,22 @@ function ComputableDAGs.kernel( init_caches = Expr(:block, tape.initCachesCode...) assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) + + # TODO use gen_function_body here code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...) function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( - ComputableDAGs.gen_access_expr( - ComputableDAGs.entry_device(tape.machine), tape.outputSymbol + ComputableDAGs._gen_access_expr( + ComputableDAGs.entry_device(tape.machine), + ComputableDAGs.entry_device(tape.machine).cacheStrategy, + tape.outputSymbol, ), ) expr = Meta.parse( "function compute_$(function_id)(input_vector, output_vector, n::Int64) id = (workgroupIdx().x - 1) * workgroupDim().x + workgroupIdx().x - if (id > n) + if (id > n) return end @inline data_input = input_vector[id] diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index 3ff1273..c87a8f3 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -14,18 +14,28 @@ using RuntimeGeneratedFunctions RuntimeGeneratedFunctions.init(@__MODULE__) ``` in your top level. + +## Keyword Arguments + +`closures_size` (default=500): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time. """ function get_compute_function( - graph::DAG, instance, machine::Machine, context_module::Module + graph::DAG, instance, machine::Machine, context_module::Module; closures_size=500 ) tape = gen_tape(graph, instance, machine, context_module) initCaches = Expr(:block, tape.initCachesCode...) assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...) - code = Expr(:block, expr_from_fc.(tape.computeCode)...) 
+ code = gen_function_body(tape.computeCode; closures_size=closures_size) functionId = to_var_name(UUIDs.uuid1(rng[1])) - resSym = eval(gen_access_expr(entry_device(tape.machine), tape.outputSymbol)) + resSym = eval( + _gen_access_expr( + entry_device(tape.machine), + entry_device(tape.machine).cacheStrategy, + tape.outputSymbol, + ), + ) expr = # Expr( :function, # function definition diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index a434f2c..c01a45a 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -52,9 +52,13 @@ end function expr_from_fc(fc::FunctionCall{VectorT,0}) where {VectorT} func_call = Expr( - :call, fc.func, eval.(gen_access_expr.(Ref(fc.device), fc.arguments))... + :call, + fc.func, + eval.( + _gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments) + )..., ) - access_expr = eval(gen_access_expr(fc.device, fc.return_symbol)) + access_expr = eval(gen_access_expr(fc)) return Expr(:(=), access_expr, func_call) end @@ -69,9 +73,11 @@ function expr_from_fc(fc::FunctionCall{VectorT,M}) where {VectorT,M} :call, fc.func, fc.value_arguments..., - eval.(gen_access_expr.(Ref(fc.device), fc.arguments))..., + eval.( + _gen_access_expr.(Ref(fc.device), Ref(fc.device.cacheStrategy), fc.arguments) + )..., ) - access_expr = eval(gen_access_expr(fc.device, fc.return_symbol)) + access_expr = eval(gen_access_expr(fc)) return Expr(:(=), access_expr, func_call) end @@ -133,6 +139,46 @@ function gen_input_assignment_code( return assign_inputs end +""" + gen_function_body(fc_vec::Vector{FunctionCall}; closures_size) + +Generate the function body from the given `Vector` of [`FunctionCall`](@ref)s. + +## Keyword Arguments +`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely. +""" +function gen_function_body(fc_vec::Vector{FunctionCall}; closures_size::Int) + if (closures_size <= 1) + return Expr(:block, expr_from_fc.(fc_vec)...) + end + + closures = Vector{Expr}() + for i in 1:closures_size:length(fc_vec) + code_block = fc_vec[i:min(i + closures_size, length(fc_vec))] + + # collect `local var` statements that need to exist before the closure starts + # since the return symbols are always unique, this has to happen for each fc and there will be no duplicates + local_inits = gen_local_init.(code_block) + + closure = Expr( # call to the following closure (no arguments) + :call, + Expr( # create the closure: () -> code block; return nothing + :->, + :(), + Expr(# # actual function body of the closure + :block, + expr_from_fc.(code_block)..., + Expr(:return, :nothing), + ), + ), + ) + # combine to one closure call, including all the local inits and the actual call to the closure + push!(closures, Expr(:block, local_inits..., closure)) + end + + return Expr(:block, closures...) +end + """ gen_tape( graph::DAG, diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 8e28b46..2f84121 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -62,3 +62,21 @@ function cpu_st() [NumaNode(0, 1, default_strategy(NumaNode), -1.0, UUIDs.uuid1())], [-1.0;;] ) end + +""" + gen_access_expr(fc::FunctionCall) + +Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_access_expr`(@ref). 
+""" +function gen_access_expr(fc::FunctionCall) + return _gen_access_expr(fc.device, fc.device.cacheStrategy, fc.return_symbol) +end + +""" + gen_local_init(fc::FunctionCall) + +Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_local_init`(@ref). +""" +function gen_local_init(fc::FunctionCall) + return _gen_local_init(fc, fc.device, fc.device.cacheStrategy) +end diff --git a/src/devices/interface.jl b/src/devices/interface.jl index 1c39111..a4377e0 100644 --- a/src/devices/interface.jl +++ b/src/devices/interface.jl @@ -100,12 +100,21 @@ The strategy is a symbol function gen_cache_init_code end """ - gen_access_expr(device::AbstractDevice, symbol::Symbol) + _gen_access_expr(device::AbstractDevice, cache_strategy::CacheStrategy, symbol::Symbol) Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref). Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`]. """ -function gen_access_expr end +function _gen_access_expr end + +""" + _gen_local_init(fc::FunctionCall, device::AbstractDevice, cache_strategy::CacheStrategy) + +Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref) and at least one [`CacheStrategy`](@ref). +Return an `Expr` or `QuoteNode` that initializes the access expression returned by [`_gen_access_expr`](@ref) in the local scope. +This expression may be empty. For local variables it should be `local ::`. +""" +function _gen_local_init end """ kernel(gpu_type::Type{<:AbstractGPU}, graph::DAG, instance) diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl index a418824..96a6504 100644 --- a/src/devices/numa/impl.jl +++ b/src/devices/numa/impl.jl @@ -64,32 +64,45 @@ function gen_cache_init_code(device::NumaNode) end """ - gen_access_expr(device::NumaNode, symbol::Symbol) + _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol) -Generate code to access the variable designated by `symbol` on a [`NumaNode`](@ref), using the [`CacheStrategy`](@ref) set in the device. +Interface implementation, dispatched to from [`gen_access_expr`](@ref). """ -function gen_access_expr(device::NumaNode, symbol::Symbol) - return _gen_access_expr(device, device.cacheStrategy, symbol) +function _gen_access_expr(::NumaNode, ::LocalVariables, symbol::Symbol) + # TODO rewrite these with Expr instead of quote node + s = Symbol("data_$symbol") + quote_node = Meta.parse(":($s)") + return quote_node end """ - _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol) + _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol) -Internal function for dispatch, used in [`gen_access_expr`](@ref). +Interface implementation, dispatched to from [`gen_access_expr`](@ref). """ -function _gen_access_expr(device::NumaNode, ::LocalVariables, symbol::Symbol) - s = Symbol("data_$symbol") - quoteNode = Meta.parse(":($s)") - return quoteNode +function _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol) + # TODO rewrite these with Expr instead of quote node + access_str = ":(cache_$(to_var_name(device.id))[:$symbol])" + quote_node = Meta.parse(access_str) + return quote_node end """ - _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol) + _gen_local_init(fc::FunctionCall, device::NumaNode, cache_strategy::LocalVariables) -Internal function for dispatch, used in [`gen_access_expr`](@ref). +Interface implementation, dispatched to from [`gen_local_init`](@ref). 
""" -function _gen_access_expr(device::NumaNode, ::Dictionary, symbol::Symbol) - accessStr = ":(cache_$(to_var_name(device.id))[:$symbol])" - quoteNode = Meta.parse(accessStr) - return quoteNode +function _gen_local_init(fc::FunctionCall, ::NumaNode, ::LocalVariables) + s = Symbol("data_$(fc.return_symbol)") + quote_node = Expr(:local, s) # TODO: figure out how to get type info for this local variable + return quote_node +end + +""" + _gen_local_init(fc::FunctionCall, device::NumaNode, cache_strategy::Dictionary) + +Interface implementation, dispatched to from [`gen_local_init`](@ref). +""" +function _gen_local_init(::FunctionCall, ::NumaNode, ::Dictionary) + return Exp() end From c8b9d79eafd5af9982888bca1cf692d259a944d4 Mon Sep 17 00:00:00 2001 From: Rubydragon Date: Thu, 3 Oct 2024 17:24:54 +0200 Subject: [PATCH 2/5] Fix type inference ability and make closures into anonymous functions --- Project.toml | 4 +- ext/devices/cuda/function.jl | 2 +- ext/devices/rocm/function.jl | 2 +- src/ComputableDAGs.jl | 1 + src/code_gen/function.jl | 2 +- src/code_gen/tape_machine.jl | 125 ++++++++++++++++++++++++++--------- src/code_gen/type.jl | 2 +- src/code_gen/utils.jl | 45 +++++++++++++ src/devices/numa/impl.jl | 2 +- src/node/create.jl | 1 - src/scheduler/greedy.jl | 36 +++++----- src/scheduler/interface.jl | 2 +- src/scheduler/type.jl | 5 +- src/task/compute.jl | 23 ++++++- 14 files changed, 188 insertions(+), 64 deletions(-) create mode 100644 src/code_gen/utils.jl diff --git a/Project.toml b/Project.toml index d488855..b2dd12c 100644 --- a/Project.toml +++ b/Project.toml @@ -22,14 +22,14 @@ CUDAExt = "CUDA" oneAPIExt = "oneAPI" [compat] -julia = "1.10" AMDGPU = "1" CUDA = "5" DataStructures = "0.18" NumaAllocators = "0.2" -oneAPI = "1" RuntimeGeneratedFunctions = "0.5" StaticArrays = "1" +julia = "1.10" +oneAPI = "1" [extras] SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index e7c392e..7f0f889 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -7,7 +7,7 @@ function ComputableDAGs.kernel( init_caches = Expr(:block, tape.initCachesCode...) assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) # TODO: use gen_function_body here - code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...) + code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl index ae8572b..64cfcfe 100644 --- a/ext/devices/rocm/function.jl +++ b/ext/devices/rocm/function.jl @@ -8,7 +8,7 @@ function ComputableDAGs.kernel( assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) # TODO use gen_function_body here - code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.computeCode)...) + code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) 
function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1])) res_sym = eval( diff --git a/src/ComputableDAGs.jl b/src/ComputableDAGs.jl index e9597ba..291fc4f 100644 --- a/src/ComputableDAGs.jl +++ b/src/ComputableDAGs.jl @@ -130,6 +130,7 @@ include("scheduler/interface.jl") include("scheduler/greedy.jl") include("code_gen/type.jl") +include("code_gen/utils.jl") include("code_gen/tape_machine.jl") include("code_gen/function.jl") diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index c87a8f3..b33ee1b 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -26,7 +26,7 @@ function get_compute_function( initCaches = Expr(:block, tape.initCachesCode...) assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...) - code = gen_function_body(tape.computeCode; closures_size=closures_size) + code = gen_function_body(tape; closures_size=closures_size) functionId = to_var_name(UUIDs.uuid1(rng[1])) resSym = eval( diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index c01a45a..8079903 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -121,14 +121,11 @@ function gen_input_assignment_code( device = entry_device(machine) fc = FunctionCall( - RuntimeGeneratedFunction( - @__MODULE__, - context_module, - Expr(:->, :x, input_expr(instance, name, :x)), - ), + context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))), SVector{0,Any}(), SVector{1,Symbol}(:input), symbol, + Nothing, device, ) @@ -140,40 +137,95 @@ function gen_input_assignment_code( end """ - gen_function_body(fc_vec::Vector{FunctionCall}; closures_size) + gen_function_body(tape::Tape; closures_size) -Generate the function body from the given `Vector` of [`FunctionCall`](@ref)s. +Generate the function body from the given [`Tape`](@ref). ## Keyword Arguments `closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely. """ -function gen_function_body(fc_vec::Vector{FunctionCall}; closures_size::Int) +function gen_function_body(tape::Tape; closures_size::Int) + if closures_size > 1 + # only need to annotate types later when using closures + infer_types!(tape) + end + + fc_vec = tape.schedule + if (closures_size <= 1) return Expr(:block, expr_from_fc.(fc_vec)...) 
end closures = Vector{Expr}() - for i in 1:closures_size:length(fc_vec) - code_block = fc_vec[i:min(i + closures_size, length(fc_vec))] + # iterate from end to beginning + # this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier + undefined_argument_symbols = Set{Symbol}() + # the final return symbol is the return of the entire generated function, it always has to be returned + push!(undefined_argument_symbols, eval(gen_access_expr(fc_vec[end]))) + + for i in length(fc_vec):(-closures_size):1 + e = i + b = max(i - closures_size, 1) + code_block = fc_vec[b:e] # collect `local var` statements that need to exist before the closure starts - # since the return symbols are always unique, this has to happen for each fc and there will be no duplicates local_inits = gen_local_init.(code_block) - closure = Expr( # call to the following closure (no arguments) - :call, - Expr( # create the closure: () -> code block; return nothing - :->, - :(), - Expr(# # actual function body of the closure - :block, - expr_from_fc.(code_block)..., - Expr(:return, :nothing), + return_symbols = eval.(gen_access_expr.(code_block)) + argument_symbols = Set{Symbol}() + + ret_symbols_set = Set(return_symbols) + for fc in code_block + for arg in fc.arguments + symbol = eval(_gen_access_expr(fc.device, fc.device.cacheStrategy, arg)) + + # symbol won't be defined if it is first calculated in the closure + # so don't add it to the arguments in this case + if !(symbol in ret_symbols_set) + push!(argument_symbols, symbol) + end + end + end + union!(undefined_argument_symbols, argument_symbols) + + intersect!(ret_symbols_set, undefined_argument_symbols) + return_symbols = Symbol[ret_symbols_set...] + + argument_symbols = [argument_symbols...] # make sure there is an order (doesn't matter which) + + closure = Expr( + :block, + Expr( + :(=), + Expr(:tuple, return_symbols...), + Expr( + :call, # call to the following closure (no arguments) + Expr( # create the closure: (args) -> code block; return (locals) + :->, + Expr(:tuple, argument_symbols...), # arguments in the closure definition + Expr( # actual function body of the closure + :block, + local_inits..., # declare local variables with type information inside the closure + expr_from_fc.(code_block)..., + Expr(:return, Expr(:tuple, return_symbols...)), + ), + ), + argument_symbols..., # arguments to the closure call ), ), ) + + setdiff!(undefined_argument_symbols, ret_symbols_set) + + #=Expr( + :macrocall, + Symbol("@closure"), + @__LINE__, + Expr( ) + )=# + # combine to one closure call, including all the local inits and the actual call to the closure - push!(closures, Expr(:block, local_inits..., closure)) + pushfirst!(closures, closure) end return Expr(:block, closures...) 
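
To make the nested `Expr` construction above easier to follow, here is a hand-written sketch of the shape of one generated closure block. It is an illustration only: `f1`, `f2`, `x`, `y`, `data_a` and `data_b` are invented names that do not appear in the patch; the real code assembles the equivalent expression programmatically from the scheduled `FunctionCall`s via `gen_local_init` and `expr_from_fc`.

```julia
# Sketch of one closure block as generated in this patch: results needed by later
# blocks are returned as a tuple from an immediately invoked closure, which acts
# as a function barrier for the compiler.
f1(x) = x + 1.0
f2(a, y) = a * y
x, y = 2.0, 3.0

(data_a, data_b) = ((x, y) -> begin
    local data_a::Float64    # corresponds to a `local <name>::<Type>` line from gen_local_init
    local data_b::Float64
    data_a = f1(x)           # one line per expr_from_fc(fc)
    data_b = f2(data_a, y)
    return (data_a, data_b)
end)(x, y)
```
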
@@ -200,25 +252,33 @@ function gen_tape(
     scheduler::AbstractScheduler=GreedyScheduler(),
 )
     schedule = schedule_dag(scheduler, graph, machine)
+    function_body = lower(schedule, machine)
 
-    # get inSymbols
-    inputSyms = Dict{String,Vector{Symbol}}()
+    # get input symbols
+    input_syms = Dict{String,Vector{Symbol}}()
     for node in get_entry_nodes(graph)
-        if !haskey(inputSyms, node.name)
-            inputSyms[node.name] = Vector{Symbol}()
+        if !haskey(input_syms, node.name)
+            input_syms[node.name] = Vector{Symbol}()
         end
 
-        push!(inputSyms[node.name], Symbol("$(to_var_name(node.id))_in"))
+        push!(input_syms[node.name], Symbol("$(to_var_name(node.id))_in"))
     end
 
     # get outSymbol
     outSym = Symbol(to_var_name(get_exit_node(graph).id))
 
-    initCaches = gen_cache_init_code(machine)
-    assign_inputs = gen_input_assignment_code(inputSyms, instance, machine, context_module)
+    init_caches = gen_cache_init_code(machine)
+    assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module)
 
     return Tape{input_type(instance)}(
-        initCaches, assign_inputs, schedule, inputSyms, outSym, Dict(), instance, machine
+        init_caches,
+        assign_inputs,
+        function_body,
+        input_syms,
+        outSym,
+        Dict(),
+        instance,
+        machine,
     )
 end
 
@@ -228,6 +288,9 @@ end
 Execute the given tape with the given input.
 
 For implementation reasons, this disregards the set [`CacheStrategy`](@ref) of the devices and always uses a dictionary.
+
+!!! warning
+    This is very slow and might not work. This is to be majorly revamped.
 """
 function execute_tape(tape::Tape, input)
     cache = Dict{Symbol,Any}()
@@ -238,10 +301,12 @@ function execute_tape(tape::Tape, input)
         @eval $expr
     end
 
+    compute_code = tape.schedule
+
     for function_call in tape.inputAssignCode
         call_fc(function_call, cache)
     end
-    for function_call in tape.computeCode
+    for function_call in compute_code
         call_fc(function_call, cache)
     end
 
diff --git a/src/code_gen/type.jl b/src/code_gen/type.jl
index 1b81166..06405b9 100644
--- a/src/code_gen/type.jl
+++ b/src/code_gen/type.jl
@@ -12,7 +12,7 @@ TODO: update docs
 struct Tape{INPUT}
     initCachesCode::Vector{Expr}
     inputAssignCode::Vector{FunctionCall}
-    computeCode::Vector{FunctionCall}
+    schedule::Vector{FunctionCall}
     inputSymbols::Dict{String,Vector{Symbol}}
     outputSymbol::Symbol
     cache::Dict{Symbol,Any}
diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl
new file mode 100644
index 0000000..8379b16
--- /dev/null
+++ b/src/code_gen/utils.jl
@@ -0,0 +1,45 @@
+"""
+    infer_types!(tape::Tape)
+
+Infer the result type of each function call in the given tape and store it in each [`FunctionCall`](@ref)'s `return_type` field. This assumes that every function call has exactly one statically inferrable return type and will throw an exception otherwise.
+This also assumes that the function calls are in a topological ordering, such as returned by a call to [`schedule_dag`](@ref).
+""" +function infer_types!(tape::Tape) + known_result_types = Dict{Symbol,Type}() + + # the only initially known type + known_result_types[:input] = input_type(tape.instance) + + for fc in tape.inputAssignCode + res_type = result_type(fc, known_result_types) + fc.return_type = res_type + known_result_types[fc.return_symbol] = res_type + end + + for fc in tape.schedule + res_type = result_type(fc, known_result_types) + fc.return_type = res_type + known_result_types[fc.return_symbol] = res_type + end + + return nothing +end + +""" + lower(schedule::Vector{Node}, machine::Machine) + +After [`schedule_dag`](@ref) has made a schedule of nodes, this function lowers the vector of [`Node`](@ref)s into a vector of [`FunctionCall`](@ref)s. +""" +function lower(schedule::Vector{Node}, machine::Machine) + calls = Vector{FunctionCall}() + + for node in schedule + if (node isa DataTaskNode && length(children(node)) == 0) + push!(calls, get_init_function_call(node, entry_device(machine))) + else + push!(calls, get_function_call(node)...) + end + end + + return calls +end diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl index 96a6504..0c61b75 100644 --- a/src/devices/numa/impl.jl +++ b/src/devices/numa/impl.jl @@ -94,7 +94,7 @@ Interface implementation, dispatched to from [`gen_local_init`](@ref). """ function _gen_local_init(fc::FunctionCall, ::NumaNode, ::LocalVariables) s = Symbol("data_$(fc.return_symbol)") - quote_node = Expr(:local, s) # TODO: figure out how to get type info for this local variable + quote_node = Expr(:local, s, :(::), Symbol(fc.return_type)) # TODO: figure out how to get type info for this local variable return quote_node end diff --git a/src/node/create.jl b/src/node/create.jl index ea255b8..6331d39 100644 --- a/src/node/create.jl +++ b/src/node/create.jl @@ -1,4 +1,3 @@ - function DataTaskNode(t::AbstractDataTask, name="") return DataTaskNode( t, diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl index d00cd99..63795f1 100644 --- a/src/scheduler/greedy.jl +++ b/src/scheduler/greedy.jl @@ -7,46 +7,42 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes struct GreedyScheduler <: AbstractScheduler end function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine) - nodeQueue = PriorityQueue{Node,Int}() + node_queue = PriorityQueue{Node,Int}() # use a priority equal to the number of unseen children -> 0 are nodes that can be added for node in get_entry_nodes(graph) - enqueue!(nodeQueue, node => 0) + enqueue!(node_queue, node => 0) end - schedule = Vector{FunctionCall}() + schedule = Vector{Node}() sizehint!(schedule, length(graph.nodes)) # keep an accumulated cost of things scheduled to this device so far - deviceAccCost = PriorityQueue{AbstractDevice,Float64}() + device_acc_cost = PriorityQueue{AbstractDevice,Float64}() for device in machine.devices - enqueue!(deviceAccCost, device => 0) + enqueue!(device_acc_cost, device => 0) end - node = nothing - while !isempty(nodeQueue) - @assert peek(nodeQueue)[2] == 0 - node = dequeue!(nodeQueue) + local node + while !isempty(node_queue) + @assert peek(node_queue)[2] == 0 + node = dequeue!(node_queue) # assign the device with lowest accumulated cost to the node (if it's a compute node) if (isa(node, ComputeTaskNode)) - lowestDevice = peek(deviceAccCost)[1] - node.device = lowestDevice - deviceAccCost[lowestDevice] = compute_effort(task(node)) + lowest_device = peek(device_acc_cost)[1] + node.device = lowest_device + device_acc_cost[lowest_device] = compute_effort(task(node)) 
end - if (node isa DataTaskNode && length(children(node)) == 0) - push!(schedule, get_init_function_call(node, entry_device(machine))) - else - push!(schedule, get_function_call(node)...) - end + push!(schedule, node) for parent in parents(node) # reduce the priority of all parents by one - if (!haskey(nodeQueue, parent)) - enqueue!(nodeQueue, parent => length(children(parent)) - 1) + if (!haskey(node_queue, parent)) + enqueue!(node_queue, parent => length(children(parent)) - 1) else - nodeQueue[parent] = nodeQueue[parent] - 1 + node_queue[parent] = node_queue[parent] - 1 end end end diff --git a/src/scheduler/interface.jl b/src/scheduler/interface.jl index b420788..1dcbfbf 100644 --- a/src/scheduler/interface.jl +++ b/src/scheduler/interface.jl @@ -15,6 +15,6 @@ The function assigns each [`ComputeTaskNode`](@ref) of the [`DAG`](@ref) to one [`DataTaskNode`](@ref)s are not scheduled to devices since they do not compute. Instead, a data node transfers data from the [`AbstractDevice`](@ref) of their child to all [`AbstractDevice`](@ref)s of its parents. -Return a `Vector{FunctionCall}`. See [`FunctionCall`](@ref) +The produced schedule can be converted to [`FunctionCall`](@ref)s using [`lower`](@ref). """ function schedule_dag end diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl index 0f76d07..008f677 100644 --- a/src/scheduler/type.jl +++ b/src/scheduler/type.jl @@ -5,11 +5,12 @@ using StaticArrays Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on. """ -struct FunctionCall{VectorType<:AbstractVector,N} +mutable struct FunctionCall{VectorType<:AbstractVector,N} func::Function # TODO: this should be a tuple value_arguments::SVector{N,Any} # value arguments for the function call, will be prepended to the other arguments - arguments::VectorType # symbols of the inputs to the function call + arguments::VectorType # symbols of the inputs to the function call return_symbol::Symbol + return_type::Type device::AbstractDevice end diff --git a/src/task/compute.jl b/src/task/compute.jl index 45b353f..8593646 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -11,7 +11,7 @@ For ordinary compute or data tasks the vector will contain exactly one element. 
function get_function_call( t::CompTask, device::AbstractDevice, in_symbols::AbstractVector, out_symbol::Symbol ) where {CompTask<:AbstractComputeTask} - return [FunctionCall(compute, SVector{1,Any}(t), in_symbols, out_symbol, device)] + return [FunctionCall(compute, SVector{1,Any}(t), in_symbols, out_symbol, Any, device)] end function get_function_call(node::ComputeTaskNode) @@ -42,7 +42,7 @@ function get_function_call(node::ComputeTaskNode) end function get_function_call(node::DataTaskNode) - @assert length(children(node)) == 1 "trying to call get_expression on a data task node that has $(length(node.children)) children instead of 1" + @assert length(children(node)) == 1 "trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1" # TODO: dispatch to device implementations generating the copy commands return [ @@ -51,19 +51,36 @@ function get_function_call(node::DataTaskNode) SVector{0,Any}(), SVector{1,Symbol}(Symbol(to_var_name(first(children(node))[1].id))), Symbol(to_var_name(node.id)), + Any, first(children(node))[1].device, ), ] end function get_init_function_call(node::DataTaskNode, device::AbstractDevice) - @assert isempty(children(node)) "trying to call get_init_expression on a data task node that is not an entry node." + @assert isempty(children(node)) "trying to call get_init_function_call on a data task node that is not an entry node." return FunctionCall( unpack_identity, SVector{0,Any}(), SVector{1,Symbol}(Symbol("$(to_var_name(node.id))_in")), Symbol(to_var_name(node.id)), + Any, device, ) end + +function result_type(fc::FunctionCall, known_res_types::Dict{Symbol,Type}) + argument_types = ( + typeof.(fc.value_arguments)..., getindex.(Ref(known_res_types), fc.arguments)... + ) + types = Base.return_types(fc.func, argument_types) + + if length(types) > 1 + throw( + "failure during type inference: function call $fc is type unstable, possible return types: $types", + ) + end + + return types[1] +end From 656f17e0193d982866ec0bc0cbaa3ff167470a0d Mon Sep 17 00:00:00 2001 From: Rubydragon Date: Fri, 4 Oct 2024 19:50:59 +0200 Subject: [PATCH 3/5] Improve required allocations dramatically --- src/code_gen/tape_machine.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index 8079903..b328a17 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -172,7 +172,6 @@ function gen_function_body(tape::Tape; closures_size::Int) local_inits = gen_local_init.(code_block) return_symbols = eval.(gen_access_expr.(code_block)) - argument_symbols = Set{Symbol}() ret_symbols_set = Set(return_symbols) for fc in code_block @@ -182,17 +181,14 @@ function gen_function_body(tape::Tape; closures_size::Int) # symbol won't be defined if it is first calculated in the closure # so don't add it to the arguments in this case if !(symbol in ret_symbols_set) - push!(argument_symbols, symbol) + push!(undefined_argument_symbols, symbol) end end end - union!(undefined_argument_symbols, argument_symbols) intersect!(ret_symbols_set, undefined_argument_symbols) return_symbols = Symbol[ret_symbols_set...] - argument_symbols = [argument_symbols...] 
# make sure there is an order (doesn't matter which) - closure = Expr( :block, Expr( @@ -200,9 +196,9 @@ function gen_function_body(tape::Tape; closures_size::Int) Expr(:tuple, return_symbols...), Expr( :call, # call to the following closure (no arguments) - Expr( # create the closure: (args) -> code block; return (locals) + Expr( # create the closure: () -> code block; return (locals) :->, - Expr(:tuple, argument_symbols...), # arguments in the closure definition + :(), # closure arguments (none) Expr( # actual function body of the closure :block, local_inits..., # declare local variables with type information inside the closure @@ -210,7 +206,6 @@ function gen_function_body(tape::Tape; closures_size::Int) Expr(:return, Expr(:tuple, return_symbols...)), ), ), - argument_symbols..., # arguments to the closure call ), ), ) From a01b9a8f5eb809ce0f16453d420d47f12227a18d Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Fri, 1 Nov 2024 13:00:20 +0100 Subject: [PATCH 4/5] Add error for failure in type inference --- src/code_gen/tape_machine.jl | 7 ------- src/task/compute.jl | 5 +++++ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl index b328a17..78eabbe 100644 --- a/src/code_gen/tape_machine.jl +++ b/src/code_gen/tape_machine.jl @@ -212,13 +212,6 @@ function gen_function_body(tape::Tape; closures_size::Int) setdiff!(undefined_argument_symbols, ret_symbols_set) - #=Expr( - :macrocall, - Symbol("@closure"), - @__LINE__, - Expr( ) - )=# - # combine to one closure call, including all the local inits and the actual call to the closure pushfirst!(closures, closure) end diff --git a/src/task/compute.jl b/src/task/compute.jl index 8593646..dd78b09 100644 --- a/src/task/compute.jl +++ b/src/task/compute.jl @@ -81,6 +81,11 @@ function result_type(fc::FunctionCall, known_res_types::Dict{Symbol,Type}) "failure during type inference: function call $fc is type unstable, possible return types: $types", ) end + if isempty(types) + throw( + "failure during type inference: function call $fc has no return types, this is likely because no method matches the arguments", + ) + end return types[1] end From f0dbbc7e3085f2eb02764dc4558215dde6ed388f Mon Sep 17 00:00:00 2001 From: AntonReinhard Date: Fri, 1 Nov 2024 13:03:30 +0100 Subject: [PATCH 5/5] Turn closures off by default --- src/code_gen/function.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl index b33ee1b..f7afbf7 100644 --- a/src/code_gen/function.jl +++ b/src/code_gen/function.jl @@ -17,10 +17,10 @@ in your top level. ## Keyword Arguments -`closures_size` (default=500): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time. +`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time. """ function get_compute_function( - graph::DAG, instance, machine::Machine, context_module::Module; closures_size=500 + graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0 ) tape = gen_tape(graph, instance, machine, context_module)
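
For reference, a minimal usage sketch of the `closures_size` keyword introduced in this series is shown below. `my_graph`, `my_instance` and `my_input` are placeholders for a user-supplied DAG, problem instance and input, and the single-argument call convention of the generated function is an assumption rather than something specified by these patches.

```julia
# Sketch only; placeholder names, not part of the patch series.
using ComputableDAGs
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)    # required per the get_compute_function docstring

machine = ComputableDAGs.cpu_st()              # single NUMA node, single thread (src/devices/impl.jl)

# closures_size=0 (the default after PATCH 5) emits one flat function body;
# a positive value such as 500 splits the body into closure blocks of roughly that
# many function calls, reducing compile time at the cost of optimizations across
# block boundaries (see the docstrings above).
compute = get_compute_function(my_graph, my_instance, machine, @__MODULE__; closures_size=500)
result = compute(my_input)
```
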