Skip to content

Commit

Permalink
perform inference using optimizer-derived type information
Browse files Browse the repository at this point in the history
In certain cases, the optimizer can introduce new type information.
This is particularly evident in SROA, where load forwarding can reveal
type information that was not visible during abstract interpretation.
In such cases, re-running abstract interpretation using this new type
information can be highly valuable, however, currently, this only occurs
when semi-concrete interpretation happens to be triggered.

This commit introduces a new "post-optimization inference" phase at the
end of the optimizer pipeline. When the optimizer derives new type
information, this phase performs IR abstract interpretation to further
optimize the IR.
  • Loading branch information
aviatesk committed Nov 27, 2024
1 parent f6ebc4b commit 231b196
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 29 deletions.
3 changes: 2 additions & 1 deletion Compiler/src/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ mutable struct IRInterpretationState
callstack #::Vector{AbsIntState}
frameid::Int
parentid::Int
new_call_inferred::Bool

function IRInterpretationState(interp::AbstractInterpreter,
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
Expand All @@ -829,7 +830,7 @@ mutable struct IRInterpretationState
edges = Any[]
callstack = AbsIntState[]
return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds), curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0, #=new_call_inferred=#false)
end
end

Expand Down
64 changes: 52 additions & 12 deletions Compiler/src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ end

# run the optimization work
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
@timeit "optimizer" ir = run_passes_ipo_safe(interp, opt, caller)
ipo_dataflow_analysis!(interp, opt, ir, caller)
return finish(interp, opt, ir, caller)
end
Expand All @@ -1012,27 +1012,25 @@ matchpass(optimize_until::Int, stage, _) = optimize_until == stage
matchpass(optimize_until::String, _, name) = optimize_until == name
matchpass(::Nothing, _, _) = false

function run_passes_ipo_safe(
ci::CodeInfo,
sv::OptimizationState,
optimize_until = nothing, # run all passes by default
)
function run_passes_ipo_safe(interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult;
optimize_until = nothing) # run all passes by default
ci = sv.src
__stage__ = 0 # used by @pass
# NOTE: The pass name MUST be unique for `optimize_until::AbstractString` to work
@pass "convert" ir = convert_to_ircode(ci, sv)
@pass "slot2reg" ir = slot2reg(ir, ci, sv)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
@pass "compact 1" ir = compact!(ir)
@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
# @timeit "verify 2" verify_ir(ir)
@pass "compact 2" ir = compact!(ir)
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
if made_changes
@pass "compact 3" ir = compact!(ir, true)
end
@pass "ADCE" ir, changed = adce_pass!(ir, sv.inlining)
@pass "compact 3" changed && (
ir = compact!(ir, true))
@pass "optinf" optinf_worthwhile(ir) && (
ir = optinf!(ir, interp, sv, result))
if is_asserts()
@timeit "verify 3" begin
@timeit "verify" begin
verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp), sv.linfo)
verify_linetable(ir.debuginfo, length(ir.stmts))
end
Expand All @@ -1041,6 +1039,48 @@ function run_passes_ipo_safe(
return ir
end

# If the optimizer derives new type information (as implied by `IR_FLAG_REFINED`),
# and this new type information is available for the arguments of a call expression,
# further optimizations may be possible by performing irinterp on the optimized IR.
function optinf_worthwhile(ir::IRCode)
@assert isempty(ir.new_nodes) "expected compacted IRCode"
for i = 1:length(ir.stmts)
if has_flag(ir[SSAValue(i)], IR_FLAG_REFINED)
stmt = ir[SSAValue(i)][:stmt]
if isexpr(stmt, :call)
return true
end
end
end
return false
end

function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult)
ci = sv.src
spec_info = SpecInfo(ci)
world = get_inference_world(interp)
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
world, min_world, max_world)
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
if irsv.new_call_inferred
ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
ir = compact!(ir)
effects = result.effects
if nothrow
effects = Effects(effects; nothrow=true)
end
if noub
effects = Effects(effects; noub=ALWAYS_TRUE)
end
result.effects = effects
result.exc_result = refine_exception_type(result.exc_result, effects)
= strictneqpartialorder(ipo_lattice(interp))
result.result = rt result.result ? rt : result.result
end
return ir
end

function strip_trailing_junk!(code::Vector{Any}, ssavaluetypes::Vector{Any}, ssaflags::Vector, debuginfo::DebugInfoStream, cfg::CFG, info::Vector{CallInfo})
# Remove `nothing`s at the end, we don't handle them well
# (we expect the last instruction to be a terminator)
Expand Down
3 changes: 2 additions & 1 deletion Compiler/src/ssair/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ using Base: # Base definitions
unwrap_unionall, !, !=, !==, &, *, +, -, :, <, <<, =>, >, |, , , , , , , , ,
hasintersect
using ..Compiler: # Compiler specific definitions
Compiler, @show, ,
AbstractLattice, Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
argextype, fieldcount_noerror, hasintersect, has_flag, intrinsic_nothrow,
is_meta_expr_head, is_identity_free_argtype, isexpr, println, setfield!_nothrow,
singleton_type, try_compute_field, try_compute_fieldidx, widenconst, , Compiler
singleton_type, try_compute_field, try_compute_fieldidx, widenconst

function include(x::String)
if !isdefined(Base, :end_base_include)
Expand Down
3 changes: 2 additions & 1 deletion Compiler/src/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1709,7 +1709,8 @@ function reprocess_phi_node!(𝕃ₒ::AbstractLattice, compact::IncrementalCompa

# There's only one predecessor left - just replace it
v = phi.values[1]
if !(𝕃ₒ, compact[compact.ssa_rename[old_idx]][:type], argextype(v, compact))
= strictneqpartialorder(𝕃ₒ)
if argextype(v, compact) compact[compact.ssa_rename[old_idx]][:type]
v = Refined(v)
end
compact.ssa_rename[old_idx] = v
Expand Down
8 changes: 6 additions & 2 deletions Compiler/src/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sstate::St
call = abstract_call(interp, arginfo, si, irsv)::Future
Future{Any}(call, interp, irsv) do call, interp, irsv
irsv.ir.stmts[irsv.curridx][:info] = call.info
irsv.new_call_inferred |= true
nothing
end
return call
Expand Down Expand Up @@ -204,7 +205,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
# Handled at the very end
return false
elseif isa(stmt, PiNode)
rt = tmeet(typeinf_lattice(interp), argextype(stmt.val, ir), widenconst(stmt.typ))
= join(typeinf_lattice(interp))
rt = argextype(stmt.val, ir) widenconst(stmt.typ)
elseif stmt === nothing
return false
elseif isa(stmt, GlobalRef)
Expand All @@ -226,7 +228,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
inst[:stmt] = quoted(rt.val)
end
return true
elseif !(typeinf_lattice(interp), inst[:type], rt)
end
= strictneqpartialorder(typeinf_lattice(interp))
if rt inst[:type]
inst[:type] = rt
return true
end
Expand Down
36 changes: 26 additions & 10 deletions Compiler/src/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -989,9 +989,10 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
lifted_leaves === nothing && return

result_t = Union{}
= join(𝕃ₒ)
for v in values(lifted_leaves)
v === nothing && return
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
result_t = result_t argextype(v.val, compact)
end

(lifted_val, nest) = perform_lifting!(compact,
Expand All @@ -1001,8 +1002,12 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
finish_phi_nest!(compact, nest)
if lifted_val !== nothing
if !(𝕃ₒ, compact[SSAValue(idx)][:type], tuple_tfunc(𝕃ₒ, Any[result_t]))
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
stmttype = tuple_tfunc(𝕃ₒ, Any[result_t])
inst = compact[SSAValue(idx)]
= strictneqpartialorder(𝕃ₒ)
if stmttype inst[:type]
inst[:type] = stmttype
add_flag!(inst, IR_FLAG_REFINED)
end
end

Expand Down Expand Up @@ -1440,19 +1445,23 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
lifted_leaves, any_undef = lifted_result

result_t = Union{}
= join(𝕃ₒ)
for v in values(lifted_leaves)
v === nothing && continue
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
result_t = result_t argextype(v.val, compact)
end

(lifted_val, nest) = perform_lifting!(compact,
visited_philikes, field, result_t, lifted_leaves, val, lazydomtree)

should_delete_node = false
line = compact[SSAValue(idx)][:line]
if lifted_val !== nothing && !(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
inst = compact[SSAValue(idx)]
line = inst[:line]
= strictneqpartialorder(𝕃ₒ)
if lifted_val !== nothing && result_t inst[:type]
compact[idx] = lifted_val === nothing ? nothing : lifted_val.val
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
inst[:type] = result_t
add_flag!(inst, IR_FLAG_REFINED)
elseif lifted_val === nothing || isa(lifted_val.val, AnySSAValue)
# Save some work in a later compaction, by inserting this into the renamer now,
# but only do this if we didn't set the REFINED flag, to save work for irinterp
Expand Down Expand Up @@ -1855,9 +1864,15 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
for use in du.uses
if use.kind === :getfield
inst = ir[SSAValue(use.idx)]
inst[:stmt] = compute_value_for_use(ir, domtree, allblocks,
newvalue = compute_value_for_use(ir, domtree, allblocks,
du, phinodes, fidx, use.idx)
add_flag!(inst, IR_FLAG_REFINED)
inst[:stmt] = newvalue
newvaluetyp = argextype(newvalue, ir)
= strictneqpartialorder(𝕃ₒ)
if newvaluetyp inst[:type]
inst[:type] = newvaluetyp
add_flag!(inst, IR_FLAG_REFINED)
end
elseif use.kind === :isdefined
continue # already rewritten if possible
elseif use.kind === :nopreserve
Expand All @@ -1878,11 +1893,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
for b in phiblocks
n = ir[phinodes[b]][:stmt]::PhiNode
result_t = Bottom
= join(𝕃ₒ)
for p in ir.cfg.blocks[b].preds
push!(n.edges, p)
v = compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, p)
push!(n.values, v)
result_t = tmerge(𝕃ₒ, result_t, argextype(v, ir))
result_t = result_t argextype(v, ir)
end
ir[phinodes[b]][:type] = result_t
end
Expand Down
3 changes: 2 additions & 1 deletion Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ function typeinf_ircode(interp::AbstractInterpreter, mi::MethodInstance,
end
(; result) = frame
opt = OptimizationState(frame, interp)
ir = run_passes_ipo_safe(opt.src, opt, optimize_until)
ir = run_passes_ipo_safe(interp, opt, result; optimize_until)
rt = widenconst(ignorelimited(result.result))
return ir, rt
end
Expand All @@ -1024,6 +1024,7 @@ function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_opti
opt = OptimizationState(frame, interp)
optimize(interp, opt, frame.result)
src = ir_to_codeinf!(opt)
src.rettype = widenconst(result.result)
end
result.src = frame.src = src
end
Expand Down
17 changes: 16 additions & 1 deletion Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3495,7 +3495,7 @@ f31974(n::Int) = f31974(1:n)
@test code_typed(f31974, Tuple{Int}) !== nothing

f_overly_abstract_complex() = Complex(Ref{Number}(1)[])
@test Base.return_types(f_overly_abstract_complex, Tuple{}) == [Complex]
@test Base.infer_return_type(f_overly_abstract_complex, Tuple{}) == Complex{Int}

# Issue 26724
const IntRange = AbstractUnitRange{<:Integer}
Expand Down Expand Up @@ -6126,3 +6126,18 @@ function func_swapglobal!_must_throw(x)
end
@test Base.infer_return_type(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) === Union{}
@test !Compiler.is_effect_free(Base.infer_effects(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) )

# opt inf
@test Base.infer_return_type((Vector{Any},)) do argtypes
box = Core.Box()
box.contents = argtypes
return length(box.contents)
end == Int
@test Base.infer_return_type((Vector{Any},)) do argtypes
local argtypesi
function cls()
argtypesi = @noinline copy(argtypes)
return length(argtypesi)
end
return @inline cls()
end == Int

0 comments on commit 231b196

Please sign in to comment.