Skip to content

Commit

Permalink
wip: adjustments to the latest inlining interface changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 18, 2024
1 parent 573ab89 commit 0f676ee
Showing 1 changed file with 41 additions and 65 deletions.
106 changes: 41 additions & 65 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,79 +182,74 @@ function simplify_kw(specTypes)
end
end

# https://github.com/JuliaLang/julia/pull/46965
@static if VERSION v"1.9.0-DEV.1535"

import Core.Compiler: CallInfo
function Core.Compiler.inlining_policy(interp::EnzymeInterpreter,
@nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
const StmtFlag = @static VERSION v"1.11.0-DEV.377" ? UInt32 : UInt8

function enzyme_inlining_policy(interp::EnzymeInterpreter, mi::MethodInstance)
method_table = Core.Compiler.method_table(interp)
specTypes = simplify_kw(mi.specTypes)

if is_primitive_func(specTypes)
@safe_debug "Blocking inlining for primitive func" mi.specTypes
return nothing
return false
end

if is_alwaysinline_func(specTypes)
@safe_debug "Forcing inlining for primitive func" mi.specTypes
@assert src !== nothing
return src
return true
end

if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
@safe_debug "Blocking inlining due to inactive rule" mi.specTypes
return nothing
return false
end

if interp.mode == API.DEM_ForwardMode
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
@safe_debug "Blocking inlining due to frule" mi.specTypes
return nothing
return false
end
else
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
@safe_debug "Blocking inling due to rrule" mi.specTypes
return nothing
return false
end
end
return nothing
end

@static if VERSION v"1.12.0-DEV.45"

# TODO

# https://github.com/JuliaLang/julia/pull/46965
elseif VERSION v"1.9.0-DEV.1535"

using Core.Compiler: CallInfo
function Core.Compiler.inlining_policy(interp::EnzymeInterpreter,
@nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::StmtFlag, mi::MethodInstance, argtypes::Vector{Any})
ret = enzyme_inlining_policy(interp, mi)
if ret isa Bool
if ret
@assert src !== nothing
return src
else
return nothing
end
end
return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter,
src::Any, info::CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
src::Any, info::CallInfo, stmt_flag::StmtFlag, mi::MethodInstance, argtypes::Vector{Any})
end

# https://github.com/JuliaLang/julia/pull/41328
elseif isdefined(Core.Compiler, :is_stmt_inline)

function Core.Compiler.inlining_policy(interp::EnzymeInterpreter,
@nospecialize(src), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})

method_table = Core.Compiler.method_table(interp)
specTypes = simplify_kw(mi.specTypes)

if is_primitive_func(specTypes)
return nothing
end

if is_alwaysinline_func(specTypes)
@assert src !== nothing
return src
end

if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
return nothing
end
if interp.mode == API.DEM_ForwardMode
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
return nothing
end
else
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
ret = enzyme_inlining_policy(interp, mi)
if ret isa Bool
if ret
@assert src !== nothing
return src
else
return nothing
end
end

return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter,
src::Any, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
end
Expand All @@ -269,34 +264,15 @@ end
Core.Compiler.inlining_policy(interp::EnzymeInterpreter) = EnzymeInliningPolicy(interp)

function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, T, <:EnzymeInliningPolicy}) where {S<:Union{Nothing, Core.Compiler.EdgeTracker}, T}
mi = todo.mi
specTypes = simplify_kw(mi.specTypes)

if is_primitive_func(specTypes)
return Core.Compiler.compileable_specialization(state.et, todo.spec.match)
end

if is_alwaysinline_func(specTypes)
@assert false "Need to mark resolve_todo function as alwaysinline, but don't know how"
end

interp = state.policy.interp
method_table = Core.Compiler.method_table(interp)
if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
return Core.Compiler.compileable_specialization(state.et, todo.spec.match)
end
if interp.mode == API.DEM_ForwardMode
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
return Core.Compiler.compileable_specialization(state.et, todo.spec.match)
end
else
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
ret = enzyme_inlining_policy(state.policy.interp, todo.mi)
if ret isa Bool
if ret
@assert false "Need to mark resolve_todo function as alwaysinline, but don't know how"
else
return Core.Compiler.compileable_specialization(state.et, todo.spec.match)
end
end

return Base.@invoke Core.Compiler.resolve_todo(
todo::InliningTodo, state::InliningState)
return Base.@invoke Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState)
end

end # @static if isdefined(Core.Compiler, :is_stmt_inline)
Expand Down

0 comments on commit 0f676ee

Please sign in to comment.