Skip to content

Commit

Permalink
fix #42078, improve the idempotency of callsite inlining (#42082)
Browse files Browse the repository at this point in the history
After #41328, inference can observe statement flags and try to re-infer
a discarded source if it's going to be inlined.
The re-inferred source will only be cached into the inference-local
cache, and won't be cached globally.
  • Loading branch information
aviatesk authored Sep 4, 2021
1 parent c3d2903 commit 876df79
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 53 deletions.
11 changes: 6 additions & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,10 +546,9 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
end
end
inf_result = InferenceResult(mi, argtypes, va_override)
frame = InferenceState(inf_result, #=cache=#false, interp)
frame = InferenceState(inf_result, #=cache=#:local, interp)
frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it
frame.parent = sv
push!(inf_cache, inf_result)
typeinf(interp, frame) || return nothing
end
result = inf_result.result
Expand Down Expand Up @@ -592,7 +591,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
return nothing
end
mi = mi::MethodInstance
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, sv)
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv)
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
return nothing
end
Expand Down Expand Up @@ -696,7 +695,9 @@ end
# This is a heuristic to avoid trying to const prop through complicated functions
# where we would spend a lot of time, but are probably unlikely to get an improved
# result anyway.
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance, sv::InferenceState)
function const_prop_methodinstance_heuristic(
interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance,
argtypes::Vector{Any}, sv::InferenceState)
method = match.method
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
Expand All @@ -715,7 +716,7 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match:
if isdefined(code, :inferred) && !cache_inlineable
cache_inf = code.inferred
if !(cache_inf === nothing)
src = inlining_policy(interp, cache_inf, get_curr_ssaflag(sv))
src = inlining_policy(interp, cache_inf, get_curr_ssaflag(sv), mi, argtypes)
cache_inlineable = src !== nothing
end
end
Expand Down
13 changes: 7 additions & 6 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ mutable struct InferenceState

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult, src::CodeInfo,
cached::Bool, interp::AbstractInterpreter)
cache::Symbol, interp::AbstractInterpreter)
(; def) = linfo = result.linfo
code = src.code::Array{Any,1}
code = src.code::Vector{Any}

sp = sptypes_from_meth_instance(linfo::MethodInstance)

Expand Down Expand Up @@ -92,6 +92,7 @@ mutable struct InferenceState
valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)

@assert cache === :no || cache === :local || cache === :global
frame = new(
InferenceParams(interp), result, linfo,
sp, slottypes, mod, 0,
Expand All @@ -103,11 +104,11 @@ mutable struct InferenceState
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
cached, false, false,
cache === :global, false, false,
CachedMethodTable(method_table(interp)),
interp)
result.result = frame
cached && push!(get_inference_cache(interp), result)
cache !== :no && push!(get_inference_cache(interp), result)
return frame
end
end
Expand Down Expand Up @@ -222,12 +223,12 @@ end

method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table

function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
src = retrieve_code_info(result.linfo)
src === nothing && return nothing
validate_code_in_debug_mode(result.linfo, src, "lowered")
return InferenceState(result, src, cached, interp)
return InferenceState(result, src, cache, interp)
end

function sptypes_from_meth_instance(linfo::MethodInstance)
Expand Down
26 changes: 18 additions & 8 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,30 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, I<:AbstractInterpreter
interp::I
end

function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_flag::UInt8)
function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_flag::UInt8,
mi::MethodInstance, argtypes::Vector{Any})
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
return src_inferred && src_inlineable ? src : nothing
elseif isa(src, OptimizationState) && isdefined(src, :ir)
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
else
# maybe we want to make inference keep the source in a local cache if a statement is going to inlined
# and re-optimize it here with disabling further inlining to avoid infinite optimization loop
# (we can even naively try to re-infer it entirely)
# but it seems like that "single-level-inlining" is more trouble and complex than it's worth
# see https://github.com/JuliaLang/julia/pull/41328/commits/0fc0f71a42b8c9d04b0dafabf3f1f17703abf2e7
return nothing
elseif src === nothing && is_stmt_inline(stmt_flag)
# if this statement is forced to be inlined, make an additional effort to find the
# inferred source in the local cache;
# we still won't find a source for a recursive call, because "single-level" inlining
# seems to be more trouble and complexity than it's worth
inf_result = cache_lookup(mi, argtypes, get_inference_cache(interp))
inf_result === nothing && return nothing
src = inf_result.src
if isa(src, CodeInfo)
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
return src_inferred ? src : nothing
elseif isa(src, OptimizationState)
return isdefined(src, :ir) ? src.ir : nothing
else
return nothing
end
end
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ end

function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
mi = todo.mi
(; match) = todo.spec::DelayedInliningSpec
(; match, atypes) = todo.spec::DelayedInliningSpec

#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
isconst, src = false, nothing
Expand Down Expand Up @@ -757,7 +757,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
return ConstantCase(src)
end

src = inlining_policy(state.interp, src, flag)
src = inlining_policy(state.interp, src, flag, mi, atypes)

if src === nothing
return compileable_specialization(et, match)
Expand Down
63 changes: 33 additions & 30 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# build (and start inferring) the inference frame for the top-level MethodInstance
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cached::Bool)
frame = InferenceState(result, cached, interp)
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cache::Symbol)
frame = InferenceState(result, cache, interp)
frame === nothing && return false
cached && lock_mi_inference(interp, result.linfo)
cache === :global && lock_mi_inference(interp, result.linfo)
return typeinf(interp, frame)
end

Expand Down Expand Up @@ -774,22 +774,30 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
mi = specialize_method(method, atypes, sparams)::MethodInstance
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance # return existing rettype if the code is already inferred
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
rettype = code.rettype
if isdefined(code, :rettype_const)
rettype_const = code.rettype_const
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
return PartialStruct(rettype, rettype_const), mi
elseif rettype <: Core.OpaqueClosure && isa(rettype_const, PartialOpaque)
return rettype_const, mi
elseif isa(rettype_const, InterConditional)
return rettype_const, mi
if code.inferred === nothing && is_stmt_inline(get_curr_ssaflag(caller))
# we already inferred this edge previously and decided to discard the inferred code,
# but the inliner will request to use it, so we re-infer it here and keep it around in the local cache
cache = :local
else
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
rettype = code.rettype
if isdefined(code, :rettype_const)
rettype_const = code.rettype_const
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
return PartialStruct(rettype, rettype_const), mi
elseif rettype <: Core.OpaqueClosure && isa(rettype_const, PartialOpaque)
return rettype_const, mi
elseif isa(rettype_const, InterConditional)
return rettype_const, mi
else
return Const(rettype_const), mi
end
else
return Const(rettype_const), mi
return rettype, mi
end
else
return rettype, mi
end
else
cache = :global # cache edge targets by default
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0
return Any, nothing
Expand All @@ -805,7 +813,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# completely new
lock_mi_inference(interp, mi)
result = InferenceResult(mi)
frame = InferenceState(result, #=cached=#true, interp) # always use the cache for edge targets
frame = InferenceState(result, cache, interp) # cache edge targets (globally by default; locally when re-inferring a discarded source for inlining)
if frame === nothing
# can't get the source for this, so we know nothing
unlock_mi_inference(interp, mi)
Expand Down Expand Up @@ -834,14 +842,9 @@ function typeinf_code(interp::AbstractInterpreter, method::Method, @nospecialize
mi = specialize_method(method, atypes, sparams)::MethodInstance
ccall(:jl_typeinf_begin, Cvoid, ())
result = InferenceResult(mi)
frame = InferenceState(result, false, interp)
frame = InferenceState(result, run_optimizer ? :global : :no, interp)
frame === nothing && return (nothing, Any)
if typeinf(interp, frame) && run_optimizer
opt_params = OptimizationParams(interp)
result.src = src = OptimizationState(frame, opt_params, interp)
optimize(interp, src, opt_params, ignorelimited(result.result))
frame.src = finish!(interp, result)
end
typeinf(interp, frame)
ccall(:jl_typeinf_end, Cvoid, ())
frame.inferred || return (nothing, Any)
return (frame.src, widenconst(ignorelimited(result.result)))
Expand Down Expand Up @@ -898,7 +901,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
return retrieve_code_info(mi)
end
lock_mi_inference(interp, mi)
frame = InferenceState(InferenceResult(mi), #=cached=#true, interp)
frame = InferenceState(InferenceResult(mi), #=cache=#:global, interp)
frame === nothing && return nothing
typeinf(interp, frame)
ccall(:jl_typeinf_end, Cvoid, ())
Expand All @@ -921,11 +924,11 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
return code.rettype
end
end
frame = InferenceResult(mi)
typeinf(interp, frame, true)
result = InferenceResult(mi)
typeinf(interp, result, :global)
ccall(:jl_typeinf_end, Cvoid, ())
frame.result isa InferenceState && return nothing
return widenconst(ignorelimited(frame.result))
result.result isa InferenceState && return nothing
return widenconst(ignorelimited(result.result))
end

# This is a bridge for the C code calling `jl_typeinf_func()`
Expand All @@ -941,7 +944,7 @@ function typeinf_ext_toplevel(interp::AbstractInterpreter, linfo::MethodInstance
ccall(:jl_typeinf_begin, Cvoid, ())
if !src.inferred
result = InferenceResult(linfo)
frame = InferenceState(result, src, #=cached=#true, interp)
frame = InferenceState(result, src, #=cache=#:global, interp)
typeinf(interp, frame)
@assert frame.inferred # TODO: deal with this better
src = frame.src
Expand Down
40 changes: 38 additions & 2 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ function f_ifelse(x)
b = ifelse(a, true, false)
return b ? x + 1 : x
end
# 2 for now because the compiler leaves a GotoNode around
@test_broken length(code_typed(f_ifelse, (String,))[1][1].code) <= 2
@test length(code_typed(f_ifelse, (String,))[1][1].code) <= 2

# Test that inlining of _apply_iterate properly hits the inference cache
@noinline cprop_inline_foo1() = (1, 1)
Expand Down Expand Up @@ -614,3 +613,40 @@ end
# Issue #41299 - inlining deletes error check in :>
g41299(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...)
@test_throws TypeError g41299(>:, 1, 2)

# https://github.com/JuliaLang/julia/issues/42078
# idempotency of callsite inlining
# Look up `mi` in the code cache of a freshly constructed `NativeInterpreter`.
# Returns the cached entry, or `nothing` when the cache has no entry for `mi`.
function getcache(mi::Core.MethodInstance)
    interp = Core.Compiler.NativeInterpreter()
    # the lookup itself already falls back to `nothing` for a missing entry
    return Core.Compiler.get(Core.Compiler.code_cache(interp), mi, nothing)
end
# `f42078` is declared `@noinline`; the callers below annotate the callsite with `@inline`
@noinline f42078(a) = sum(sincos(a))
# Check that callsite inlining of `f42078` produces the same code before and after
# the cached inferred source of `f42078(::Int)` has been discarded.
let
# baseline: number of statements in the caller when no `:invoke` of `f42078` remains
ninlined = let
code = code_typed1((Int,)) do a
@inline f42078(a)
end
@test all(x->!isinvoke(x, :f42078), code)
length(code)
end

let # codegen will discard the source because it's not supposed to be inlined in a general context
a = 42
f42078(a)
end
let # make sure to discard the inferred source
specs = collect(only(methods(f42078)).specializations)
# take the first non-`nothing` entry among the method's specializations
mi = specs[findfirst(!isnothing, specs)]::Core.MethodInstance
codeinf = getcache(mi)::Core.CodeInstance
# explicitly clear the cached source so the check below cannot rely on it
codeinf.inferred = nothing
end

let # inference should re-infer `f42078(::Int)` and we should get the same code
code = code_typed1((Int,)) do a
@inline f42078(a)
end
@test all(x->!isinvoke(x, :f42078), code)
# same statement count as the baseline, i.e. callsite inlining is idempotent
@test ninlined == length(code)
end
end

0 comments on commit 876df79

Please sign in to comment.