From d515f05812d544d6a4a5aa7486835035f424f17c Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Mon, 29 Nov 2021 11:09:39 +0900 Subject: [PATCH] optimizer: refactor SROA pass (#43232) - avoid domtree construction when there are no mutables to eliminate - reduce # of dynamic allocations - separate some computations into individual functions --- base/compiler/ssair/passes.jl | 173 ++++++++++++++++++---------------- 1 file changed, 91 insertions(+), 82 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 7dc5856244cf3..1bf92c81770b9 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -23,7 +23,7 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[]) compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses) -function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr) +function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) field = stmt.args[3] # fields are usually literals, handle them manually if isa(field, QuoteNode) @@ -31,7 +31,7 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr) elseif isa(field, Int) # try to resolve other constants, e.g. global reference else - field = compact_exprtype(compact, field) + field = isa(ir, IncrementalCompact) ? 
compact_exprtype(ir, field) : argextype(field, ir) if isa(field, Const) field = field.val else @@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr) return field end -function try_compute_fieldidx_stmt(compact::IncrementalCompact, stmt::Expr, typ::DataType) - field = try_compute_field_stmt(compact, stmt) +function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType) + field = try_compute_field_stmt(ir, stmt) return try_compute_fieldidx(typ, field) end @@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, return def, stmtblock, curblock end +function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint)) + if isa(val, Union{OldSSAValue, SSAValue}) + val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) + end + return walk_to_defs(compact, val, typeconstraint) +end + function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), callback = (@nospecialize(pi), @nospecialize(idx)) -> false) while true @@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), - @nospecialize(typeconstraint = types(compact)[defssa])) + @nospecialize(typeconstraint)) callback = function (@nospecialize(pi), @nospecialize(idx)) if isa(pi, PiNode) typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)) @@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss end """ - walk_to_defs(compact, val, intermediaries) + walk_to_defs(compact, val, typeconstraint) -Starting at `val` walk use-def chains to get all the leaves feeding into -this val (pruning those leaves rules out by path conditions). 
+Starting at `val` walk use-def chains to get all the leaves feeding into this `val` +(pruning those leaves ruled out by path conditions). """ -function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), visited_phinodes::Vector{AnySSAValue}=AnySSAValue[]) - isa(defssa, AnySSAValue) || return Any[defssa] +function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint)) + visited_phinodes = AnySSAValue[] + isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes def = compact[defssa] - isa(def, PhiNode) || return Any[defssa] - # Step 2: Figure out what the struct is defined as - ## Track definitions through PiNode/PhiNode - found_def = false - ## Track which PhiNodes, SSAValue intermediaries - ## we forwarded through. + isa(def, PhiNode) || return Any[defssa], visited_phinodes visited_constraints = IdDict{AnySSAValue, Any}() worklist_defs = AnySSAValue[] worklist_constraints = Any[] @@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe push!(leaves, defssa) end end - leaves + return leaves, visited_phinodes end -function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr) +function process_immutable_preserve!(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr) for arg in (isexpr(def, :new) ? 
def.args : def.args[2:end]) if !isbitstype(widenconst(compact_exprtype(compact, arg))) push!(new_preserves, arg) @@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact, return end - if isa(val, Union{OldSSAValue, SSAValue}) - val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) - end - - visited_phinodes = AnySSAValue[] - leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes) + valtyp = widenconst(compact_exprtype(compact, val)) + isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting + leaves, visited_phinodes = collect_leaves(compact, val, valtyp) length(leaves) ≤ 1 && return # bail out if we don't have multiple leaves # Let's check if we evaluate the comparison for each one of the leaves @@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact, visited_phinodes, cmp, lifting_cache, Bool, lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, val)::LiftedValue - # global assertion_counter - # assertion_counter::Int += 1 - # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true) - # return compact[idx] = lifted_val.x end @@ -576,6 +572,10 @@ function perform_lifting!(compact::IncrementalCompact, return stmt_val # N.B. should never happen end +# NOTE we use `IdSet{Int}` instead of `BitSet` for `sroa_pass!` since it works on IR after inlining, +# which can be very large sometimes, and analyzed program counters are often very sparse +const SPCSet = IdSet{Int} + """ sroa_pass!(ir::IRCode) -> newir::IRCode @@ -596,17 +596,16 @@ a result of succeeding dead code elimination. 
""" function sroa_pass!(ir::IRCode) compact = IncrementalCompact(ir) - defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}() + defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() for ((_, idx), stmt) in compact + # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue - result_t = compact_exprtype(compact, SSAValue(idx)) is_setfield = false field_ordering = :unspecified - # Step 1: Check whether the statement we're looking at is a getfield/setfield! if is_known_call(stmt, setfield!, compact) - is_setfield = true 4 <= length(stmt.args) <= 5 || continue + is_setfield = true if length(stmt.args) == 5 field_ordering = compact_exprtype(compact, stmt.args[5]) end @@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode) old_preserves = stmt.args[(6+nccallargs):end] for (pidx, preserved_arg) in enumerate(old_preserves) isa(preserved_arg, SSAValue) || continue - let intermediaries = IdSet{Int}() + let intermediaries = SPCSet() callback = function (@nospecialize(pi), @nospecialize(ssa)) push!(intermediaries, ssa.id) return false @@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode) defidx = def.id def = compact[defidx] if is_tuple_call(compact, def) - process_immutable_preserve(new_preserves, compact, def) + process_immutable_preserve!(new_preserves, compact, def) old_preserves[pidx] = nothing continue elseif isexpr(def, :new) @@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode) typ = unwrap_unionall(typ) end if typ isa DataType && !ismutabletype(typ) - process_immutable_preserve(new_preserves, compact, def) + process_immutable_preserve!(new_preserves, compact, def) old_preserves[pidx] = nothing continue end else continue end - mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse())) + if defuses === nothing + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + end + mid, defuse = 
get!(defuses, defidx, (SPCSet(), SSADefUse())) push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) end @@ -675,10 +677,15 @@ function sroa_pass!(ir::IRCode) else continue end + + # analyze this `getfield` / `setfield!` call + field = try_compute_field_stmt(compact, stmt) field === nothing && continue - struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2]))) + val = stmt.args[2] + + struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, val))) if isa(struct_typ, Union) && struct_typ <: Tuple struct_typ = unswitchtupleunion(struct_typ) end @@ -689,19 +696,21 @@ function sroa_pass!(ir::IRCode) continue end - def, typeconstraint = stmt.args[2], struct_typ - + # analyze this mutable struct here for the later pass if ismutabletype(struct_typ) - isa(def, SSAValue) || continue - let intermediaries = IdSet{Int}() + isa(val, SSAValue) || continue + let intermediaries = SPCSet() callback = function (@nospecialize(pi), @nospecialize(ssa)) push!(intermediaries, ssa.id) return false end - def = simple_walk(compact, def, callback) + def = simple_walk(compact, val, callback) # Mutable stuff here isa(def, SSAValue) || continue - mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse())) + if defuses === nothing + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + end + mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) if is_setfield push!(defuse.defs, idx) else @@ -711,32 +720,28 @@ function sroa_pass!(ir::IRCode) end continue elseif is_setfield - continue + continue # invalid `setfield!` call, but just ignore here end # perform SROA on immutable structs here on - if isa(def, Union{OldSSAValue, SSAValue}) - def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint) - end - - visited_phinodes = AnySSAValue[] - leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes) - - isempty(leaves) && continue - field = try_compute_fieldidx(struct_typ, field) field === nothing && continue - r = 
lift_leaves(compact, result_t, field, leaves) - r === nothing && continue - lifted_leaves, any_undef = r + leaves, visited_phinodes = collect_leaves(compact, val, struct_typ) + isempty(leaves) && continue + + result_t = compact_exprtype(compact, SSAValue(idx)) + lifted_result = lift_leaves(compact, result_t, field, leaves) + lifted_result === nothing && continue + lifted_leaves, any_undef = lifted_result if any_undef result_t = make_MaybeUndef(result_t) end - val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2]) + val = perform_lifting!(compact, + visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val) # Insert the undef check if necessary if any_undef @@ -750,28 +755,32 @@ function sroa_pass!(ir::IRCode) @assert val !== nothing end - # global assertion_counter - # assertion_counter::Int += 1 - # insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true) - # continue compact[idx] = val === nothing ? nothing : val.x end non_dce_finish!(compact) - # Copy the use count, `simple_dce!` may modify it and for our predicate - # below we need it consistent with the state of the IR here (after tracking - # phi node arguments, but before dce). - used_ssas = copy(compact.used_ssas) - simple_dce!(compact) - ir = complete(compact) - - # Compute domtree, needed below, now that we have finished compacting the - # IR. This needs to be after we iterate through the IR with - # `IncrementalCompact` because removing dead blocks can invalidate the - # domtree. 
+ if defuses !== nothing + # now go through analyzed mutable structs and see which ones we can eliminate + # NOTE copy the use count here, because `simple_dce!` may modify it and we need it + # consistent with the state of the IR here (after tracking `PhiNode` arguments, + # but before the DCE) for our predicate within `sroa_mutables!` + used_ssas = copy(compact.used_ssas) + simple_dce!(compact) + ir = complete(compact) + sroa_mutables!(ir, defuses, used_ssas) + return ir + else + simple_dce!(compact) + return complete(compact) + end +end + +function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) + # Compute domtree, needed below, now that we have finished compacting the IR. + # This needs to be after we iterate through the IR with `IncrementalCompact` + # because removing dead blocks can invalidate the domtree. @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks) - # Now go through any mutable structs and see which ones we can eliminate for (idx, (intermediaries, defuse)) in defuses intermediaries = collect(intermediaries) # Check if there are any uses we did not account for. If so, the variable @@ -806,12 +815,12 @@ function sroa_pass!(ir::IRCode) # it would have been deleted. That's fine, just ignore # the use in that case. stmt === nothing && continue - field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ) + field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) field === nothing && @goto skip push!(fielddefuse[field].uses, use) end for use in defuse.defs - field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ) + field = try_compute_fieldidx_stmt(ir, ir[SSAValue(use)]::Expr, typ) field === nothing && @goto skip push!(fielddefuse[field].defs, use) end @@ -846,8 +855,9 @@ function sroa_pass!(ir::IRCode) end end end - preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses))) # Everything accounted for. 
Go field by field and perform idf + preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing : + IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses))) for fidx in 1:ndefuse du = fielddefuse[fidx] ftyp = fieldtype(typ, fidx) @@ -863,8 +873,10 @@ function sroa_pass!(ir::IRCode) ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt) end if !isbitstype(ftyp) - for (use, list) in preserve_uses - push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use)) + if preserve_uses !== nothing + for (use, list) in preserve_uses + push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use)) + end end end for b in phiblocks @@ -881,7 +893,7 @@ function sroa_pass!(ir::IRCode) ir[SSAValue(stmt)] = nothing end end - isempty(defuse.ccall_preserve_uses) && continue + preserve_uses === nothing && continue push!(intermediaries, newidx) # Insert the new preserves for (use, new_preserves) in preserve_uses @@ -897,10 +909,7 @@ function sroa_pass!(ir::IRCode) @label skip end - - return ir end -# assertion_counter = 0 """ canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)