diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index ab93375db4d0e..ae3d5cc931cbe 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -371,7 +371,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) if isa(stmt′, ReturnNode) val = stmt′.val - isa(val, SSAValue) && (compact.used_ssas[val.id] += 1) return_value = SSAValue(idx′) inline_compact[idx′] = val inline_compact.result[idx′][:type] = @@ -428,13 +427,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector just_fixup!(inline_compact) compact.result_idx = inline_compact.result_idx compact.active_result_bb = inline_compact.active_result_bb - for i = 1:length(pn.values) - isassigned(pn.values, i) || continue - v = pn.values[i] - if isa(v, SSAValue) - compact.used_ssas[v.id] += 1 - end - end if length(pn.edges) == 1 return_value = pn.values[1] else diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index b3f01e8a4a415..73e70973166c9 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -31,6 +31,11 @@ end function block_for_inst(index::Vector{Int}, inst::Int) return searchsortedfirst(index, inst, lt=(<=)) end + +function block_for_inst(index::Vector{BasicBlock}, inst::Int) + return searchsortedfirst(index, BasicBlock(StmtRange(inst, inst)), by=x->first(x.stmts), lt=(<=))-1 +end + block_for_inst(cfg::CFG, inst::Int) = block_for_inst(cfg.index, inst) function basic_blocks_starts(stmts::Vector{Any}) @@ -120,9 +125,6 @@ function compute_basic_blocks(stmts::Vector{Any}) # :enter gets a virtual edge to the exception handler and # the exception handler gets a virtual edge from outside # the function. - # See the devdocs on exception handling in SSA form (or - # bug Keno to write them, if you're reading this and they - # don't exist) block′ = block_for_inst(basic_block_index, terminator.args[1]::Int) push!(blocks[block′].preds, num) push!(blocks[block′].preds, 0) @@ -556,6 +558,7 @@ mutable struct IncrementalCompact new_nodes_idx::Int # This supports insertion while compacting new_new_nodes::NewNodeStream # New nodes that were before the compaction point at insertion time + new_new_used_ssas::Vector{Int} # TODO: Switch these two to a min-heap of some sort pending_nodes::NewNodeStream # New nodes that were after the compaction point at insertion time pending_perm::Vector{Int} @@ -576,6 +579,7 @@ mutable struct IncrementalCompact new_len = length(code.stmts) + length(code.new_nodes) result = InstructionStream(new_len) used_ssas = fill(0, new_len) + new_new_used_ssas = Vector{Int}() blocks = code.cfg.blocks if allow_cfg_transforms bb_rename = Vector{Int}(undef, length(blocks)) @@ -618,7 +622,7 @@ mutable struct IncrementalCompact pending_nodes = NewNodeStream() pending_perm = Int[] return new(code, result, result_bbs, ssa_rename, bb_rename, bb_rename, used_ssas, late_fixup, perm, 1, - new_new_nodes, pending_nodes, pending_perm, + new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm, 1, 1, 1, false, allow_cfg_transforms, allow_cfg_transforms) end @@ -627,7 +631,7 @@ mutable struct IncrementalCompact perm = my_sortperm(Int[code.new_nodes.info[i].pos for i in 1:length(code.new_nodes)]) new_len = length(code.stmts) + length(code.new_nodes) ssa_rename = Any[SSAValue(i) for i = 1:new_len] - used_ssas = fill(0, new_len) + new_new_used_ssas = Vector{Int}() late_fixup = Vector{Int}() bb_rename = Vector{Int}() new_new_nodes = NewNodeStream() @@ -636,7 +640,7 @@ mutable struct IncrementalCompact return new(code, parent.result, parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas, late_fixup, perm, 1, - new_new_nodes, pending_nodes, pending_perm, + new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm, 1, result_offset, parent.active_result_bb, false, false, false) end end @@ -646,6 +650,7 @@ struct TypesView{T} end types(ir::Union{IRCode, IncrementalCompact}) = TypesView(ir) +# TODO We can be a bit better about access here by using a pattern similar to InstructionStream function getindex(compact::IncrementalCompact, idx::Int) if idx < compact.result_idx return compact.result[idx][:inst] @@ -661,7 +666,10 @@ end function getindex(compact::IncrementalCompact, ssa::OldSSAValue) id = ssa.id - if id <= length(compact.ir.stmts) + if id < compact.idx + new_idx = compact.ssa_rename[id] + return compact.result[new_idx][:inst] + elseif id <= length(compact.ir.stmts) return compact.ir.stmts[id][:inst] end id -= length(compact.ir.stmts) @@ -676,21 +684,85 @@ function getindex(compact::IncrementalCompact, ssa::NewSSAValue) return compact.new_new_nodes.stmts[ssa.id][:inst] end +function block_for_inst(compact::IncrementalCompact, idx::SSAValue) + id = idx.id + if id < compact.result_idx # if ssa within result + return block_for_inst(compact.result_bbs, id) + else + return block_for_inst(compact.ir.cfg, id) + end +end + +function block_for_inst(compact::IncrementalCompact, idx::OldSSAValue) + id = idx.id + if id < compact.idx # if ssa within result + return block_for_inst(compact.result_bbs, compact.ssa_rename[id]) + else + return block_for_inst(compact.ir.cfg, id) + end +end + +function block_for_inst(compact::IncrementalCompact, idx::NewSSAValue) + block_for_inst(compact, SSAValue(compact.new_new_nodes.info[idx.id].pos)) +end + +function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAValue, y::AnySSAValue) + xb = block_for_inst(compact, x) + yb = block_for_inst(compact, y) + if xb == yb + xinfo = yinfo = nothing + if isa(x, OldSSAValue) + x′ = compact.ssa_rename[x.id]::SSAValue + elseif isa(x, NewSSAValue) + xinfo = compact.new_new_nodes.info[x.id] + x′ = SSAValue(xinfo.pos) + else + x′ = x + end + if isa(y, OldSSAValue) + y′ = compact.ssa_rename[y.id]::SSAValue + elseif isa(y, NewSSAValue) + yinfo = compact.new_new_nodes.info[y.id] + y′ = SSAValue(yinfo.pos) + else + y′ = y + end + if x′.id == y′.id && (xinfo !== nothing || yinfo !== nothing) + if xinfo !== nothing && yinfo !== nothing + if xinfo.attach_after == yinfo.attach_after + return x.id < y.id + end + return yinfo.attach_after + elseif xinfo !== nothing + return !xinfo.attach_after + else + return yinfo.attach_after + end + end + return x′.id < y′.id + end + return dominates(domtree, xb, yb) +end + function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) - needs_late_fixup = isa(v, NewSSAValue) + needs_late_fixup = false if isa(v, SSAValue) compact.used_ssas[v.id] += 1 + elseif isa(v, NewSSAValue) + compact.new_new_used_ssas[v.id] += 1 + needs_late_fixup = true else for ops in userefs(v) val = ops[] if isa(val, SSAValue) compact.used_ssas[val.id] += 1 elseif isa(val, NewSSAValue) + compact.new_new_used_ssas[val.id] += 1 needs_late_fixup = true end end end - needs_late_fixup + return needs_late_fixup end function add_pending!(compact::IncrementalCompact, pos::Int, attach_after::Bool) @@ -708,6 +780,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, count_added_node!(compact, inst.stmt) line = something(inst.line, compact.result[before.id][:line]) node = add!(compact.new_new_nodes, before.id, attach_after) + push!(compact.new_new_used_ssas, 0) node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag return NewSSAValue(node.idx) else @@ -726,6 +799,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, count_added_node!(compact, inst.stmt) line = something(inst.line, compact.result[renamed.id][:line]) node = add!(compact.new_new_nodes, renamed.id, attach_after) + push!(compact.new_new_used_ssas, 0) node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag return NewSSAValue(node.idx) else @@ -747,6 +821,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, line = something(inst.line, compact.new_new_nodes.stmts[before.id][:line]) new_entry = add!(compact.new_new_nodes, before_entry.pos, attach_after) new_entry[:inst], new_entry[:type], new_entry[:line], new_entry[:flag] = inst.stmt, inst.type, line, inst.flag + push!(compact.new_new_used_ssas, 0) return NewSSAValue(new_entry.idx) else error("Unsupported") @@ -773,9 +848,7 @@ function insert_node_here!(compact::IncrementalCompact, inst::NewInstruction, re end node = compact.result[result_idx] node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, inst.line, flag - if count_added_node!(compact, inst.stmt) - push!(compact.late_fixup, result_idx) - end + count_added_node!(compact, inst.stmt) && push!(compact.late_fixup, result_idx) compact.result_idx = result_idx + 1 inst = SSAValue(result_idx) refinish && finish_current_bb!(compact, 0) @@ -797,22 +870,50 @@ function getindex(view::TypesView, v::OldSSAValue) return view.ir.pending_nodes.stmts[id][:type] end -function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue) - @assert idx.id < compact.result_idx - (compact.result[idx.id][:inst] === v) && return - # Kill count for current uses - for ops in userefs(compact.result[idx.id][:inst]) +function kill_current_uses(compact::IncrementalCompact, @nospecialize(stmt)) + for ops in userefs(stmt) val = ops[] if isa(val, SSAValue) @assert compact.used_ssas[val.id] >= 1 compact.used_ssas[val.id] -= 1 + elseif isa(val, NewSSAValue) + @assert compact.new_new_used_ssas[val.id] >= 1 + compact.new_new_used_ssas[val.id] -= 1 end end +end + +function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue) + @assert idx.id < compact.result_idx + (compact.result[idx.id][:inst] === v) && return + # Kill count for current uses + kill_current_uses(compact, compact.result[idx.id][:inst]) compact.result[idx.id][:inst] = v # Add count for new use - if count_added_node!(compact, v) - push!(compact.late_fixup, idx.id) + count_added_node!(compact, v) && push!(compact.late_fixup, idx.id) + return compact +end + +function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::OldSSAValue) + id = idx.id + if id < compact.idx + new_idx = compact.ssa_rename[id] + (compact.result[new_idx][:inst] === v) && return + kill_current_uses(compact, compact.result[new_idx][:inst]) + compact.result[new_idx][:inst] = v + count_added_node!(compact, v) && push!(compact.late_fixup, new_idx) + return compact + elseif id <= length(compact.ir.stmts) # ir.stmts, new_nodes, and pending_nodes uses aren't counted yet, so no need to adjust + compact.ir.stmts[id][:inst] = v + return compact end + id -= length(compact.ir.stmts) + if id <= length(compact.ir.new_nodes) + compact.ir.new_nodes.stmts[id][:inst] = v + return compact + end + id -= length(compact.ir.new_nodes) + compact.pending_nodes.stmts[id][:inst] = v return compact end @@ -856,6 +957,7 @@ end function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int}, processed_idx::Int, result_idx::Int, ssa_rename::Vector{Any}, used_ssas::Vector{Int}, + new_new_used_ssas::Vector{Int}, do_rename_ssa::Bool) values = Vector{Any}(undef, length(old_values)) for i = 1:length(old_values) @@ -867,7 +969,7 @@ function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int} push!(late_fixup, result_idx) val = OldSSAValue(val.id) else - val = renumber_ssa2(val, ssa_rename, used_ssas, do_rename_ssa) + val = renumber_ssa2(val, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa) end else used_ssas[val.id] += 1 @@ -877,17 +979,19 @@ function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int} push!(late_fixup, result_idx) else # Always renumber these. do_rename_ssa applies only to actual SSAValues - val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, true) + val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, new_new_used_ssas, true) end elseif isa(val, NewSSAValue) push!(late_fixup, result_idx) + new_new_used_ssas[val.id] += 1 end values[i] = val end return values end -function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int}, do_rename_ssa::Bool) +function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int}, + new_new_used_ssas::Vector{Int}, do_rename_ssa::Bool) id = val.id if id > length(ssanums) return val @@ -896,22 +1000,26 @@ function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssas::Vector{In val = ssanums[id] end if isa(val, SSAValue) - if used_ssas !== nothing - used_ssas[val.id] += 1 - end + used_ssas[val.id] += 1 end return val end -function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Vector{Int}, late_fixup::Vector{Int}, result_idx::Int, do_rename_ssa::Bool) +function renumber_ssa2(val::NewSSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int}, + new_new_used_ssas::Vector{Int}, do_rename_ssa::Bool) + new_new_used_ssas[val.id] += 1 + return val +end + +function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Vector{Int}, new_new_used_ssas::Vector{Int}, late_fixup::Vector{Int}, result_idx::Int, do_rename_ssa::Bool) urs = userefs(stmt) for op in urs val = op[] if isa(val, OldSSAValue) || isa(val, NewSSAValue) push!(late_fixup, result_idx) end - if isa(val, SSAValue) - val = renumber_ssa2(val, ssanums, used_ssas, do_rename_ssa) + if isa(val, Union{SSAValue, NewSSAValue}) + val = renumber_ssa2(val, ssanums, used_ssas, new_new_used_ssas, do_rename_ssa) end if isa(val, OldSSAValue) || isa(val, NewSSAValue) push!(late_fixup, result_idx) @@ -991,16 +1099,13 @@ end function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instruction, idx::Int, processed_idx::Int, active_bb::Int, do_rename_ssa::Bool) stmt = inst[:inst] - result = compact.result - ssa_rename = compact.ssa_rename - late_fixup = compact.late_fixup - used_ssas = compact.used_ssas + (; result, ssa_rename, late_fixup, used_ssas, new_new_used_ssas, cfg_transforms_enabled, fold_constant_branches) = compact ssa_rename[idx] = SSAValue(result_idx) if stmt === nothing ssa_rename[idx] = stmt elseif isa(stmt, OldSSAValue) ssa_rename[idx] = ssa_rename[stmt.id] - elseif isa(stmt, GotoNode) && compact.cfg_transforms_enabled + elseif isa(stmt, GotoNode) && cfg_transforms_enabled result[result_idx][:inst] = GotoNode(compact.bb_rename_succ[stmt.label]) result_idx += 1 elseif isa(stmt, GlobalRef) @@ -1010,11 +1115,11 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr elseif isa(stmt, GotoNode) result[result_idx][:inst] = stmt result_idx += 1 - elseif isa(stmt, GotoIfNot) && compact.cfg_transforms_enabled - stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::GotoIfNot + elseif isa(stmt, GotoIfNot) && cfg_transforms_enabled + stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)::GotoIfNot result[result_idx][:inst] = stmt cond = stmt.cond - if compact.fold_constant_branches + if fold_constant_branches if !isa(cond, Bool) condT = widenconditional(argextype(cond, compact)) isa(condT, Const) || @goto bail @@ -1036,8 +1141,8 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr result_idx += 1 end elseif isa(stmt, Expr) - stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr - if compact.cfg_transforms_enabled && isexpr(stmt, :enter) + stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr + if cfg_transforms_enabled && isexpr(stmt, :enter) stmt.args[1] = compact.bb_rename_succ[stmt.args[1]::Int] end result[result_idx][:inst] = stmt @@ -1046,10 +1151,11 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr # As an optimization, we eliminate any trivial pinodes. For performance, we use === # type equality. We may want to consider using == in either a separate pass or if # performance turns out ok - stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::PiNode + stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)::PiNode pi_val = stmt.val if isa(pi_val, SSAValue) - if stmt.typ === compact.result[pi_val.id][:type] + if stmt.typ === result[pi_val.id][:type] + used_ssas[pi_val.id] -= 1 ssa_rename[idx] = pi_val return result_idx end @@ -1068,10 +1174,10 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr result[result_idx][:inst] = stmt result_idx += 1 elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode) || isa(stmt, GotoIfNot) - result[result_idx][:inst] = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa) + result[result_idx][:inst] = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa) result_idx += 1 elseif isa(stmt, PhiNode) - if compact.cfg_transforms_enabled + if cfg_transforms_enabled # Rename phi node edges map!(i -> compact.bb_rename_pred[i], stmt.edges, stmt.edges) @@ -1105,7 +1211,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr values = stmt.values end - values = process_phinode_values(values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa) + values = process_phinode_values(values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa) # Don't remove the phi node if it is before the definition of its value # because doing so can create forward references. This should only # happen with dead loops, but can cause problems when optimization @@ -1114,17 +1220,21 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr # just to be safe. before_def = isassigned(values, 1) && (v = values[1]; isa(v, OldSSAValue)) && idx < v.id if length(edges) == 1 && isassigned(values, 1) && !before_def && - length(compact.cfg_transforms_enabled ? + length(cfg_transforms_enabled ? compact.result_bbs[compact.bb_rename_succ[active_bb]].preds : compact.ir.cfg.blocks[active_bb].preds) == 1 # There's only one predecessor left - just replace it + @assert !isa(values[1], NewSSAValue) + if isa(values[1], SSAValue) + used_ssas[values[1].id] -= 1 + end ssa_rename[idx] = values[1] else result[result_idx][:inst] = PhiNode(edges, values) result_idx += 1 end elseif isa(stmt, PhiCNode) - result[result_idx][:inst] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)) + result[result_idx][:inst] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa)) result_idx += 1 elseif isa(stmt, SSAValue) # identity assign, replace uses of this ssa value with its result @@ -1326,31 +1436,34 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}= end function maybe_erase_unused!( - extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, + extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, in_worklist::Bool, callback = null_dce_callback) - stmt = compact.result[idx][:inst] + + inst = idx <= length(compact.result) ? compact.result[idx] : + compact.new_new_nodes.stmts[idx - length(compact.result)] + stmt = inst[:inst] stmt === nothing && return false - if argextype(SSAValue(idx), compact) === Bottom + if inst[:type] === Bottom effect_free = false else - effect_free = compact.result[idx][:flag] & IR_FLAG_EFFECT_FREE != 0 + effect_free = inst[:flag] & IR_FLAG_EFFECT_FREE != 0 end - if effect_free - for ops in userefs(stmt) - val = ops[] - # If the pass we ran inserted new nodes, it's possible for those - # to be outside our used_ssas count. - if isa(val, SSAValue) && val.id <= length(compact.used_ssas) - if compact.used_ssas[val.id] == 1 - if val.id < idx - push!(extra_worklist, val.id) - end - end - compact.used_ssas[val.id] -= 1 - callback(val) + function kill_ssa_value(val::SSAValue) + if compact.used_ssas[val.id] == 1 + if val.id < idx || in_worklist + push!(extra_worklist, val.id) end end - compact.result[idx][:inst] = nothing + compact.used_ssas[val.id] -= 1 + callback(val) + end + if effect_free + if isa(stmt, SSAValue) + kill_ssa_value(stmt) + else + foreachssa(kill_ssa_value, stmt) + end + inst[:inst] = nothing return true end return false @@ -1361,13 +1474,8 @@ function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{A for i = 1:length(old_values) isassigned(old_values, i) || continue val = old_values[i] - if isa(val, OldSSAValue) - val = compact.ssa_rename[val.id] - if isa(val, SSAValue) - compact.used_ssas[val.id] += 1 - end - elseif isa(val, NewSSAValue) - val = SSAValue(length(compact.result) + val.id) + if isa(val, Union{OldSSAValue, NewSSAValue}) + val = fixup_node(compact, val) end values[i] = val end @@ -1382,29 +1490,30 @@ function fixup_node(compact::IncrementalCompact, @nospecialize(stmt)) elseif isa(stmt, NewSSAValue) return SSAValue(length(compact.result) + stmt.id) elseif isa(stmt, OldSSAValue) - return compact.ssa_rename[stmt.id] + val = compact.ssa_rename[stmt.id] + if isa(val, SSAValue) + # If `val.id` is greater than the length of `compact.result` or + # `compact.used_ssas`, this SSA value is in `new_new_nodes`, so + # don't count the use + compact.used_ssas[val.id] += 1 + end + return val else urs = userefs(stmt) for ur in urs val = ur[] - if isa(val, NewSSAValue) - val = SSAValue(length(compact.result) + val.id) - elseif isa(val, OldSSAValue) - val = compact.ssa_rename[val.id] - end - if isa(val, SSAValue) && val.id <= length(compact.used_ssas) - # If `val.id` is greater than the length of `compact.result` or - # `compact.used_ssas`, this SSA value is in `new_new_nodes`, so - # don't count the use - compact.used_ssas[val.id] += 1 + if isa(val, Union{NewSSAValue, OldSSAValue}) + ur[] = fixup_node(compact, val) end - ur[] = val end return urs[] end end function just_fixup!(compact::IncrementalCompact) + resize!(compact.used_ssas, length(compact.result)) + append!(compact.used_ssas, compact.new_new_used_ssas) + empty!(compact.new_new_used_ssas) for idx in compact.late_fixup stmt = compact.result[idx][:inst] new_stmt = fixup_node(compact, stmt) @@ -1422,14 +1531,14 @@ end function simple_dce!(compact::IncrementalCompact, callback = null_dce_callback) # Perform simple DCE for unused values + @assert isempty(compact.new_new_used_ssas) # just_fixup! wasn't run? extra_worklist = Int[] for (idx, nused) in Iterators.enumerate(compact.used_ssas) - idx >= compact.result_idx && break nused == 0 || continue - maybe_erase_unused!(extra_worklist, compact, idx, callback) + maybe_erase_unused!(extra_worklist, compact, idx, false, callback) end while !isempty(extra_worklist) - maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist), callback) + maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist), true, callback) end end diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 7aeb303bc03a2..7fcaa79a468d5 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -440,10 +440,15 @@ function lift_arg!( lifted = stmt.args[argidx] if is_old(compact, leaf) && isa(lifted, SSAValue) lifted = OldSSAValue(lifted.id) + if already_inserted(compact, lifted) + lifted = compact.ssa_rename[lifted.id] + end end if isa(lifted, GlobalRef) || isa(lifted, Expr) lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, argextype(lifted, compact)))) + compact[leaf] = nothing stmt.args[argidx] = lifted + compact[leaf] = stmt if isa(leaf, SSAValue) && leaf.id < compact.result_idx push!(compact.late_fixup, leaf.id) end @@ -556,7 +561,7 @@ function lift_comparison_leaves!(@specialize(tfunc), # perform lifting lifted_val = perform_lifting!(compact, visited_phinodes, cmp, lifting_cache, Bool, - lifted_leaves::LiftedLeaves, val)::LiftedValue + lifted_leaves::LiftedLeaves, val, ()->nothing, idx)::LiftedValue compact[idx] = lifted_val.x end @@ -576,9 +581,43 @@ end function perform_lifting!(compact::IncrementalCompact, visited_phinodes::Vector{AnySSAValue}, @nospecialize(cache_key), lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}, - @nospecialize(result_t), lifted_leaves::LiftedLeaves, @nospecialize(stmt_val)) + @nospecialize(result_t), lifted_leaves::LiftedLeaves, @nospecialize(stmt_val), get_domtree, idx::Int) reverse_mapping = IdDict{AnySSAValue, Int}(ssa => id for (id, ssa) in enumerate(visited_phinodes)) + # Check if all the lifted leaves are the same + local the_leaf + all_same = true + for (_, val) in lifted_leaves + if !@isdefined(the_leaf) + the_leaf = val + continue + end + if val !== the_leaf + all_same = false + end + end + + the_leaf_val = isa(the_leaf, LiftedValue) ? the_leaf.x : nothing + if !isa(the_leaf_val, SSAValue) + all_same = false + end + + if all_same + dominates_all = true + domtree = get_domtree() + if domtree !== nothing + for item in visited_phinodes + if !dominates_ssa(compact, domtree, the_leaf_val, item) + dominates_all = false + break + end + end + if dominates_all + return the_leaf + end + end + end + # Insert PhiNodes lifted_phis = LiftedPhi[] for item in visited_phinodes @@ -632,10 +671,7 @@ function perform_lifting!(compact::IncrementalCompact, # Probably ignored by path condition, skip this end end - end - - for lf in lifted_phis - count_added_node!(compact, lf.node) + count_added_node!(compact, new_node) end # Fixup the stmt itself @@ -678,6 +714,14 @@ function sroa_pass!(ir::IRCode) compact = IncrementalCompact(ir) defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + # initialization of domtree is delayed to avoid the expensive computation in many cases + local domtree = nothing + function get_domtree() + if domtree === nothing + @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks) + end + return domtree + end for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue @@ -740,6 +784,7 @@ function sroa_pass!(ir::IRCode) continue end if !isempty(new_preserves) + compact[idx] = nothing compact[idx] = form_new_preserves(stmt, preserved, new_preserves) end continue @@ -819,7 +864,7 @@ function sroa_pass!(ir::IRCode) end val = perform_lifting!(compact, - visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val) + visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val, get_domtree, idx) # Insert the undef check if necessary if any_undef @@ -846,7 +891,7 @@ function sroa_pass!(ir::IRCode) used_ssas = copy(compact.used_ssas) simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) + sroa_mutables!(ir, defuses, used_ssas, get_domtree) return ir else simple_dce!(compact) @@ -854,9 +899,7 @@ function sroa_pass!(ir::IRCode) end end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) - # initialization of domtree is delayed to avoid the expensive computation in many cases - local domtree = nothing +function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, get_domtree) for (idx, (intermediaries, defuse)) in defuses intermediaries = collect(intermediaries) # Check if there are any uses we did not account for. If so, the variable @@ -920,15 +963,13 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse if isempty(ldu.live_in_bbs) phiblocks = Int[] else - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) + phiblocks = iterated_dominance_frontier(ir.cfg, ldu, get_domtree()) end allblocks = sort(vcat(phiblocks, ldu.def_bbs)) blocks[fidx] = phiblocks, allblocks if fidx + 1 > length(defexpr.args) for use in du.uses - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip + has_safe_def(ir, get_domtree(), allblocks, du, newidx, use) || @goto skip end end end @@ -936,7 +977,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse # Compute domtree now, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` # because removing dead blocks can invalidate the domtree. - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) + domtree = get_domtree() preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing : IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses))) for fidx in 1:ndefuse @@ -1031,12 +1072,12 @@ function canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::E compact.ssa_rename[compact.idx-1] = pi end -function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int) +function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, in_worklist::Bool) # return whether this made a change if isa(compact.result[idx][:inst], PhiNode) - return maybe_erase_unused!(extra_worklist, compact, idx, val::SSAValue -> phi_uses[val.id] -= 1) + return maybe_erase_unused!(extra_worklist, compact, idx, in_worklist, val::SSAValue -> phi_uses[val.id] -= 1) else - return maybe_erase_unused!(extra_worklist, compact, idx) + return maybe_erase_unused!(extra_worklist, compact, idx, in_worklist) end end @@ -1189,10 +1230,10 @@ function adce_pass!(ir::IRCode) for (idx, nused) in Iterators.enumerate(compact.used_ssas) idx >= compact.result_idx && break nused == 0 || continue - adce_erase!(phi_uses, extra_worklist, compact, idx) + adce_erase!(phi_uses, extra_worklist, compact, idx, false) end while !isempty(extra_worklist) - adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist), true) end # Go back and erase any phi cycles changed = true @@ -1211,7 +1252,7 @@ function adce_pass!(ir::IRCode) end end while !isempty(extra_worklist) - if adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + if adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist), true) changed = true end end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 38f60b0fd12aa..045cf833944c2 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -834,3 +834,25 @@ end @test Core.Compiler.builtin_effects(getfield, Any[Complex{Int}, Symbol], Any).effect_free.state == 0x01 @test Core.Compiler.builtin_effects(getglobal, Any[Module, Symbol], Any).effect_free.state == 0x01 +# Test that UseRefIterator gets SROA'd inside of new_to_regular (#44557) +# expression and new_to_regular offset are arbitrary here, we just want to see the UseRefIterator erased +let e = Expr(:call, Core.GlobalRef(Base, :arrayset), false, Core.SSAValue(4), Core.SSAValue(9), Core.SSAValue(8)) + new_to_reg(expr) = Core.Compiler.new_to_regular(expr, 1) + @allocated new_to_reg(e) # warmup call + @test (@allocated new_to_reg(e)) == 0 +end + +# Test that SROA doesn't try to forward a previous iteration's SSA value +let sroa_no_forward() = begin + res = (0, 0) + for i in 1:5 + a = first(res) + a == 5 && error() + if i == 1 + res = (i, 2.0) + end + end + return res + end + @test sroa_no_forward() == (1, 2.0) +end