Skip to content

Commit

Permalink
add support for inlining :invoke exprs (#46606)
Browse files Browse the repository at this point in the history
`NativeInterpreter` won't need this, but provide a support for `:invoke`
exprs here for external `AbstractInterpreter`s that may run the inlining
pass multiple times.

Co-authored-by: Shuhei Kadowaki <aviatesk@gmail.com>
  • Loading branch information
JeffBezanson and aviatesk authored Oct 1, 2022
1 parent 02574e3 commit 0d00660
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 28 deletions.
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ function run_passes(
@pass "slot2reg" ir = slot2reg(ir, ci, sv)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
@pass "compact 1" ir = compact!(ir)
@pass "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
# @timeit "verify 2" verify_ir(ir)
@pass "compact 2" ir = compact!(ir)
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
Expand Down
120 changes: 94 additions & 26 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct Signature
end

struct ResolvedInliningSpec
# The LineTable and IR of the inlinee
# The IR of the inlinee
ir::IRCode
# If the function being inlined is a single basic block we can use a
# simpler inlining algorithm. This flag determines whether that's allowed
Expand Down Expand Up @@ -62,7 +62,7 @@ end

struct InliningCase
sig # Type
item # Union{InliningTodo, MethodInstance, ConstantCase}
item # Union{InliningTodo, InvokeCase, ConstantCase}
function InliningCase(@nospecialize(sig), @nospecialize(item))
@assert isa(item, Union{InliningTodo, InvokeCase, ConstantCase}) "invalid inlining item"
return new(sig, item)
Expand Down Expand Up @@ -97,13 +97,13 @@ function add_inlining_backedge!((; et, invokesig)::InliningEdgeTracker, mi::Meth
return nothing
end

function ssa_inlining_pass!(ir::IRCode, linetable::Vector{LineInfoNode}, state::InliningState, propagate_inbounds::Bool)
function ssa_inlining_pass!(ir::IRCode, state::InliningState, propagate_inbounds::Bool)
# Go through the function, performing simple inlining (e.g. replacing call by constants
# and analyzing legality of inlining).
@timeit "analysis" todo = assemble_inline_todo!(ir, state)
isempty(todo) && return ir
# Do the actual inlining for every call we identified
@timeit "execution" ir = batch_inline!(todo, ir, linetable, propagate_inbounds, state.params)
@timeit "execution" ir = batch_inline!(todo, ir, propagate_inbounds, state.params)
return ir
end

Expand Down Expand Up @@ -656,7 +656,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
return insert_node_here!(compact, NewInstruction(pn, typ, line))
end

function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vector{LineInfoNode}, propagate_inbounds::Bool, params::OptimizationParams)
function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, propagate_inbounds::Bool, params::OptimizationParams)
# Compute the new CFG first (modulo statement ranges, which will be computed below)
state = CFGInliningState(ir)
for (idx, item) in todo
Expand Down Expand Up @@ -693,7 +693,12 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect
for ((old_idx, idx), stmt) in compact
if old_idx == inline_idx
stmt = stmt::Expr
argexprs = copy(stmt.args)
if stmt.head === :invoke
argexprs = stmt.args[2:end]
else
@assert stmt.head === :call
argexprs = copy(stmt.args)
end
refinish = false
if compact.result_idx == first(compact.result_bbs[compact.active_result_bb].stmts)
compact.active_result_bb -= 1
Expand All @@ -712,9 +717,9 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect
end
end
if isa(item, InliningTodo)
compact.ssa_rename[old_idx] = ir_inline_item!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs)
compact.ssa_rename[old_idx] = ir_inline_item!(compact, idx, argexprs, ir.linetable, item, boundscheck, state.todo_bbs)
elseif isa(item, UnionSplit)
compact.ssa_rename[old_idx] = ir_inline_unionsplit!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs, params)
compact.ssa_rename[old_idx] = ir_inline_unionsplit!(compact, idx, argexprs, ir.linetable, item, boundscheck, state.todo_bbs, params)
end
compact[idx] = nothing
refinish && finish_current_bb!(compact, 0)
Expand Down Expand Up @@ -847,6 +852,27 @@ end
compileable_specialization(result::InferenceResult, args...; kwargs...) = (@nospecialize;
compileable_specialization(result.linfo, args...; kwargs...))

struct CachedResult
src::Any
effects::Effects
CachedResult(@nospecialize(src), effects::Effects) = new(src, effects)
end
@inline function get_cached_result(state::InliningState, mi::MethodInstance)
code = get(state.mi_cache, mi, nothing)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
return ConstantCase(quoted(code.rettype_const))
else
src = @atomic :monotonic code.inferred
end
effects = decode_effects(code.ipo_purity_bits)
return CachedResult(src, effects)
else # fallback pass for external AbstractInterpreter cache
return CachedResult(code, Effects())
end
end

function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
mi = todo.mi
(; match, argtypes, invokesig) = todo.spec::DelayedInliningSpec
Expand All @@ -864,20 +890,12 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
end
effects = match.ipo_effects
else
code = get(state.mi_cache, mi, nothing)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
add_inlining_backedge!(et, mi)
return ConstantCase(quoted(code.rettype_const))
else
src = @atomic :monotonic code.inferred
end
effects = decode_effects(code.ipo_purity_bits)
else # fallback pass for external AbstractInterpreter cache
effects = Effects()
src = code
cached_result = get_cached_result(state, mi)
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
return cached_result
end
(; src, effects) = cached_result
end

# the duplicated check might have been done already within `analyze_method!`, but still
Expand All @@ -896,6 +914,28 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
end

function resolve_todo(mi::MethodInstance, argtypes::Vector{Any}, state::InliningState, flag::UInt8)
if !state.params.inlining || is_stmt_noinline(flag)
return nothing
end

et = InliningEdgeTracker(state.et, nothing)

cached_result = get_cached_result(state, mi)
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
return cached_result
end
(; src, effects) = cached_result

src = inlining_policy(state.interp, src, flag, mi, argtypes)

src === nothing && return nothing

add_inlining_backedge!(et, mi)
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
end

function resolve_todo((; fully_covered, atype, cases, #=bbs=#)::UnionSplit, state::InliningState, flag::UInt8)
ncases = length(cases)
newcases = Vector{InliningCase}(undef, ncases)
Expand Down Expand Up @@ -1015,7 +1055,7 @@ function handle_single_case!(
isinvoke && rewrite_invoke_exprargs!(stmt)
push!(todo, idx=>(case::InliningTodo))
end
nothing
return nothing
end

rewrite_invoke_exprargs!(expr::Expr) = (expr.args = invoke_rewrite(expr.args); expr)
Expand Down Expand Up @@ -1068,14 +1108,21 @@ end

function call_sig(ir::IRCode, stmt::Expr)
isempty(stmt.args) && return nothing
ft = argextype(stmt.args[1], ir)
if stmt.head === :call
offset = 1
elseif stmt.head === :invoke
offset = 2
else
return nothing
end
ft = argextype(stmt.args[offset], ir)
has_free_typevars(ft) && return nothing
f = singleton_type(ft)
f === Core.Intrinsics.llvmcall && return nothing
f === Core.Intrinsics.cglobal && return nothing
argtypes = Vector{Any}(undef, length(stmt.args))
argtypes[1] = ft
for i = 2:length(stmt.args)
for i = (offset+1):length(stmt.args)
a = argextype(stmt.args[i], ir)
(a === Bottom || isvarargtype(a)) && return nothing
argtypes[i] = a
Expand Down Expand Up @@ -1244,6 +1291,10 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
inline_splatnew!(ir, idx, stmt, rt)
elseif head === :new_opaque_closure
narrow_opaque_closure!(ir, stmt, ir.stmts[idx][:info], state)
elseif head === :invoke
sig = call_sig(ir, stmt)
sig === nothing && return nothing
return stmt, sig
end
check_effect_free!(ir, idx, stmt, rt)
return nothing
Expand Down Expand Up @@ -1593,6 +1644,16 @@ function handle_finalizer_call!(
return nothing
end

function handle_invoke!(todo::Vector{Pair{Int,Any}},
idx::Int, stmt::Expr, flag::UInt8, sig::Signature, state::InliningState)
mi = stmt.args[1]::MethodInstance
case = resolve_todo(mi, sig.argtypes, state, flag)
if case !== nothing
push!(todo, idx=>(case::InliningTodo))
end
return nothing
end

function inline_const_if_inlineable!(inst::Instruction)
rt = inst[:type]
if rt isa Const && is_inlineable_constant(rt.val)
Expand All @@ -1611,6 +1672,15 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
simpleres === nothing && continue
stmt, sig = simpleres

flag = ir.stmts[idx][:flag]

# `NativeInterpreter` won't need this, but provide a support for `:invoke` exprs here
# for external `AbstractInterpreter`s that may run the inlining pass multiple times
if isexpr(stmt, :invoke)
handle_invoke!(todo, idx, stmt, flag, sig, state)
continue
end

info = ir.stmts[idx][:info]

# Check whether this call was @pure and evaluates to a constant
Expand All @@ -1623,8 +1693,6 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
continue
end

flag = ir.stmts[idx][:flag]

if isa(info, OpaqueClosureCallInfo)
result = info.result
if isa(result, ConstPropResult)
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/EscapeAnalysis/EAUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::Optimizati
interp.state = state
interp.linfo = sv.linfo
end
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
# @timeit "verify 2" verify_ir(ir)
@timeit "compact 2" ir = compact!(ir)
if caller.linfo.specTypes === interp.entry_tt && interp.optimize
Expand Down
18 changes: 18 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1743,3 +1743,21 @@ let src = code_typed1((Atomic{Int},Union{Int,Float64})) do a, b
end
@test count(isinvokemodify(:mymax), src.code) == 2
end

# apply `ssa_inlining_pass` multiple times
let interp = Core.Compiler.NativeInterpreter()
# check if callsite `@noinline` annotation works
ir, = Base.code_ircode((Int,Int); optimize_until="inlining", interp) do a, b
@noinline a*b
end |> only
i = findfirst(isinvoke(:*), ir.stmts.inst)
@test i !== nothing

# ok, now delete the callsite flag, and see the second inlining pass can inline the call
@eval Core.Compiler $ir.stmts[$i][:flag] &= ~IR_FLAG_NOINLINE
inlining = Core.Compiler.InliningState(Core.Compiler.OptimizationParams(interp), nothing,
Core.Compiler.code_cache(interp), interp)
ir = Core.Compiler.ssa_inlining_pass!(ir, inlining, false)
@test count(isinvoke(:*), ir.stmts.inst) == 0
@test count(iscall((ir, Core.Intrinsics.mul_int)), ir.stmts.inst) == 1
end

0 comments on commit 0d00660

Please sign in to comment.