Skip to content

Commit

Permalink
optimizer: supports callsite annotations of inlining, fixes #18773
Browse files Browse the repository at this point in the history
Enable `@inline`/`@noinline` annotations on function callsites.
From #40754.

Now `@inline` and `@noinline` can be applied to a code block and then
the compiler will try to (not) inline calls within the block:
```julia
@inline f(...)

@noinline f(...) + g(...)

@inline f(args...) = ...
```

Here are a couple of notes on how these callsite annotations work:
- a callsite annotation always takes precedence over the annotation
  applied to the definition of the called function, regardless of whether
  `@inline` or `@noinline` is used:
  ```julia
  @inline function explicit_inline(args...)
      # body
  end

  let
      @noinline explicit_inline(args...) # this call will not be inlined
  end
  ```
- when callsite annotations are nested, the innermost annotation takes
  precedence:
  ```julia
  @noinline let a0, b0 = ...
      a = @inline f(a0)  # the compiler will try to inline this call
      b = notinlined(b0) # the compiler will NOT try to inline this call
      return a, b
  end
  ```
Both behaviors are tested and covered in the documentation.

Co-authored-by: Joseph Tan <jdtan638@gmail.com>
  • Loading branch information
aviatesk and dghosef committed Jun 23, 2021
1 parent 37c0b06 commit d21935c
Show file tree
Hide file tree
Showing 16 changed files with 294 additions and 70 deletions.
11 changes: 7 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
return nothing
end
mi = mi::MethodInstance
if !force && !const_prop_methodinstance_heuristic(interp, method, mi)
if !force && !const_prop_methodinstance_heuristic(interp, match, mi)
add_remark!(interp, sv, "[constprop] Disabled by heuristic")
return nothing
end
Expand Down Expand Up @@ -648,7 +648,8 @@ end
# This is a heuristic to avoid trying to const prop through complicated functions
# where we would spend a lot of time, but are probably unlikely to get an improved
# result anyway.
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance)
method = match.method
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
# with the const-prop-ability. It is quite possible that we can't infer
Expand All @@ -666,7 +667,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method
if isdefined(code, :inferred) && !cache_inlineable
cache_inf = code.inferred
if !(cache_inf === nothing)
cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing
# TODO maybe we want to respect callsite `@inline`/`@noinline` annotations here ?
cache_inlineable = inlining_policy(interp)(cache_inf, nothing, match) !== nothing
end
end
if !cache_inlineable
Expand Down Expand Up @@ -1806,7 +1808,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), changes, false)
end
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
elseif hd === :code_coverage_effect ||
(hd !== :boundscheck && hd !== nothing && is_meta_expr_head(hd)) # :boundscheck can be narrowed to Bool
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
Expand Down
48 changes: 43 additions & 5 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, P}
policy::P
end

function default_inlining_policy(@nospecialize(src))
function default_inlining_policy(@nospecialize(src), stmt_flag::Union{Nothing,UInt8}, match::Union{MethodMatch,InferenceResult})
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
return src_inferred && src_inlineable ? src : nothing
end
if isa(src, OptimizationState) && isdefined(src, :ir)
return src.src.inlineable ? src.ir : nothing
elseif isa(src, OptimizationState) && isdefined(src, :ir)
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
elseif src === nothing && is_stmt_inline(stmt_flag) && isa(match, MethodMatch)
# when the source isn't available at this moment, try to re-infer and inline it
# HACK in order to avoid cycles here, we disable inlining and make sure the following inference never comes here
# TODO sort out `AbstractInterpreter` interface to handle this well, and also inference should try to keep the source if the statement will be inlined
interp = NativeInterpreter(; opt_params = OptimizationParams(; inlining = false))
src, rt = typeinf_code(interp, match.method, match.spec_types, match.sparams, true)
return src
end
return nothing
end
Expand Down Expand Up @@ -134,6 +140,10 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError
# This statement was marked as @inbounds by the user. If replaced by inlining,
# any contained boundschecks may be removed
const IR_FLAG_INBOUNDS = 0x01
# This statement was marked as @inline by the user
const IR_FLAG_INLINE = 0x01 << 1
# This statement was marked as @noinline by the user
const IR_FLAG_NOINLINE = 0x01 << 2
# This statement may be removed if its result is unused. In particular it must
# thus be both pure and effect free.
const IR_FLAG_EFFECT_FREE = 0x01 << 4
Expand Down Expand Up @@ -179,6 +189,11 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
return inlineable
end

# Predicates over a statement's IR flag byte for callsite inlining annotations.
# Each one tests a single bit of `stmt_flag` (`IR_FLAG_INLINE` / `IR_FLAG_NOINLINE`);
# the `::Nothing` methods report `false` when no statement flag is supplied.
is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE != 0
is_stmt_inline(::Nothing) = false
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE != 0
is_stmt_noinline(::Nothing) = false # not used for now

# These affect control flow within the function (so may not be removed
# if there is no usage within the function), but don't affect the purity
# of the function as a whole.
Expand Down Expand Up @@ -366,6 +381,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
renumber_ir_elements!(code, changemap, labelmap)

inbounds_depth = 0 # Number of stacked inbounds
inline_flags = BitVector()
meta = Any[]
flags = fill(0x00, length(code))
for i = 1:length(code)
Expand All @@ -380,6 +396,20 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
inbounds_depth -= 1
end
stmt = nothing
elseif isexpr(stmt, :inline)
if stmt.args[1]::Bool
push!(inline_flags, true)
else
pop!(inline_flags)
end
stmt = nothing
elseif isexpr(stmt, :noinline)
if stmt.args[1]::Bool
push!(inline_flags, false)
else
pop!(inline_flags)
end
stmt = nothing
else
stmt = normalize(stmt, meta)
end
Expand All @@ -388,8 +418,16 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
if inbounds_depth > 0
flags[i] |= IR_FLAG_INBOUNDS
end
if !isempty(inline_flags)
if last(inline_flags)
flags[i] |= IR_FLAG_INLINE
else
flags[i] |= IR_FLAG_NOINLINE
end
end
end
end
@assert isempty(inline_flags) "malformed meta flags"
strip_trailing_junk!(ci, code, stmtinfo, flags)
cfg = compute_basic_blocks(code)
types = Any[]
Expand Down
69 changes: 31 additions & 38 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any},
arg_start::Int, istate::InliningState)

flag = ir.stmts[idx][:flag]
new_argexprs = Any[argexprs[arg_start]]
new_atypes = Any[atypes[arg_start]]
# loop over original arguments and flatten any known iterators
Expand Down Expand Up @@ -655,8 +656,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
info = call.info
handled = false
if isa(info, ConstCallInfo)
if maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig,
call.rt, istate, false, todo)
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
ir, state1.id, new_stmt, info, new_sig,call.rt, istate, flag, false, todo)

handled = true
else
info = info.call
Expand All @@ -667,7 +669,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
MethodMatchInfo[info] : info.matches
# See if we can inline this call to `iterate`
analyze_single_call!(ir, todo, state1.id, new_stmt,
new_sig, call.rt, info, istate)
new_sig, call.rt, info, istate, flag)
end
if i != length(thisarginfo.each)
valT = getfield_tfunc(call.rt, Const(1))
Expand Down Expand Up @@ -716,16 +718,16 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::Inf
return mi
end

function resolve_todo(todo::InliningTodo, state::InliningState)
spec = todo.spec::DelayedInliningSpec
function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
(; match) = todo.spec::DelayedInliningSpec

#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
isconst, src = false, nothing
if isa(spec.match, InferenceResult)
let inferred_src = spec.match.src
if isa(match, InferenceResult)
let inferred_src = match.src
if isa(inferred_src, Const)
if !is_inlineable_constant(inferred_src.val)
return compileable_specialization(state.et, spec.match)
return compileable_specialization(state.et, match)
end
isconst, src = true, quoted(inferred_src.val)
else
Expand Down Expand Up @@ -753,12 +755,10 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
return ConstantCase(src)
end

if src !== nothing
src = state.policy(src)
end
src = state.policy(src, flag, match)

if src === nothing
return compileable_specialization(et, spec.match)
return compileable_specialization(et, match)
end

if isa(src, IRCode)
Expand All @@ -769,17 +769,9 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
return InliningTodo(todo.mi, src)
end

function resolve_todo(todo::UnionSplit, state::InliningState)
function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8)
UnionSplit(todo.fully_covered, todo.atype,
Pair{Any,Any}[sig=>resolve_todo(item, state) for (sig, item) in todo.cases])
end

function resolve_todo!(todo::Vector{Pair{Int, Any}}, state::InliningState)
for i = 1:length(todo)
idx, item = todo[i]
todo[i] = idx=>resolve_todo(item, state)
end
todo
Pair{Any,Any}[sig=>resolve_todo(item, state, flag) for (sig, item) in todo.cases])
end

function validate_sparams(sparams::SimpleVector)
Expand All @@ -790,7 +782,7 @@ function validate_sparams(sparams::SimpleVector)
end

function analyze_method!(match::MethodMatch, atypes::Vector{Any},
state::InliningState, @nospecialize(stmttyp))
state::InliningState, @nospecialize(stmttyp), flag::UInt8)
method = match.method
methsig = method.sig

Expand All @@ -806,11 +798,9 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
end

# Bail out if any static parameters are left as TypeVar
ok = true
validate_sparams(match.sparams) || return nothing


if !state.params.inlining
if !state.params.inlining || is_stmt_noinline(flag)
return compileable_specialization(state.et, match)
end

Expand All @@ -824,7 +814,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
# If we don't have caches here, delay resolving this MethodInstance
# until the batch inlining step (or an external post-processing pass)
state.mi_cache === nothing && return todo
return resolve_todo(todo, state)
return resolve_todo(todo, state, flag)
end

function InliningTodo(mi::MethodInstance, ir::IRCode)
Expand Down Expand Up @@ -1050,7 +1040,7 @@ is_builtin(s::Signature) =
s.ft ⊑ Builtin

function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallInfo,
state::InliningState, todo::Vector{Pair{Int, Any}})
state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8)
stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]

Expand All @@ -1064,7 +1054,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn
atypes = atypes[4:end]
pushfirst!(atypes, atype0)

result = analyze_method!(info.match, atypes, state, calltype)
result = analyze_method!(info.match, atypes, state, calltype, flag)
handle_single_case!(ir, stmt, idx, result, true, todo)
return nothing
end
Expand Down Expand Up @@ -1159,7 +1149,7 @@ end

function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo},
state::InliningState)
state::InliningState, flag::UInt8)
cases = Pair{Any, Any}[]
signature_union = Union{}
only_method = nothing # keep track of whether there is one matching method
Expand Down Expand Up @@ -1192,7 +1182,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
fully_covered = false
continue
end
case = analyze_method!(match, sig.atypes, state, calltype)
case = analyze_method!(match, sig.atypes, state, calltype, flag)
if case === nothing
fully_covered = false
continue
Expand All @@ -1219,7 +1209,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
match = meth[1]
end
fully_covered = true
case = analyze_method!(match, sig.atypes, state, calltype)
case = analyze_method!(match, sig.atypes, state, calltype, flag)
case === nothing && return
push!(cases, Pair{Any,Any}(match.spec_types, case))
end
Expand All @@ -1241,7 +1231,7 @@ end

function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
info::ConstCallInfo, sig::Signature, @nospecialize(calltype),
state::InliningState,
state::InliningState, flag::UInt8,
isinvoke::Bool, todo::Vector{Pair{Int, Any}})
# when multiple matches are found, bail out and later inliner will union-split this signature
# TODO effectively use multiple constant analysis results here
Expand All @@ -1253,7 +1243,7 @@ function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
validate_sparams(item.mi.sparam_vals) || return true
mthd_sig = item.mi.def.sig
mistypes = item.mi.specTypes
state.mi_cache !== nothing && (item = resolve_todo(item, state))
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
if sig.atype <: mthd_sig
handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
return true
Expand Down Expand Up @@ -1291,6 +1281,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
info = info.info
end

flag = ir.stmts[idx][:flag]

# Inference determined this couldn't be analyzed. Don't question it.
if info === false
continue
Expand All @@ -1300,23 +1292,24 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
# it'll have performed a specialized analysis for just this case. Use its
# result.
if isa(info, ConstCallInfo)
if maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, sig.f === Core.invoke, todo)
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo)
continue
else
info = info.call
end
end

if isa(info, OpaqueClosureCallInfo)
result = analyze_method!(info.match, sig.atypes, state, calltype)
result = analyze_method!(info.match, sig.atypes, state, calltype, flag)
handle_single_case!(ir, stmt, idx, result, false, todo)
continue
end

# Handle invoke
if sig.f === Core.invoke
if isa(info, InvokeCallInfo)
inline_invoke!(ir, idx, sig, info, state, todo)
inline_invoke!(ir, idx, sig, info, state, todo, flag)
end
continue
end
Expand All @@ -1330,7 +1323,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
continue
end

analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state)
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag)
end
todo
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
nslots = length(ci.slotflags)
resize!(ci.slottypes, nslots)
resize!(ci.slotnames, nslots)
return ccall(:jl_compress_ir, Any, (Any, Any), def, ci)
return ccall(:jl_compress_ir, Vector{UInt8}, (Any, Any), def, ci)
else
return ci
end
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ swapped in as long as they follow the AbstractInterpreter API.
All AbstractInterpreters are expected to provide at least the following methods:
- InferenceParams(interp) - return an `InferenceParams` instance
- OptimizationParams(interp) - return an `OptimizationParams` instance
- get_world_counter(interp) - return the world age for this interpreter
- get_inference_cache(interp) - return the runtime inference cache
- `InferenceParams(interp)` - return an `InferenceParams` instance
- `OptimizationParams(interp)` - return an `OptimizationParams` instance
- `get_world_counter(interp)` - return the world age for this interpreter
- `get_inference_cache(interp)` - return the runtime inference cache
"""
abstract type AbstractInterpreter; end

Expand Down
Loading

0 comments on commit d21935c

Please sign in to comment.