Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor lattice code to expose layering and enable easy extension #46526

Merged
merged 1 commit into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 67 additions & 55 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

173 changes: 173 additions & 0 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Root of the lattice layering hierarchy. Each concrete subtype represents one
# layer of inference extensions stacked on top of a less-precise parent lattice.
abstract type AbstractLattice; end

"""
    widenlattice(lattice::AbstractLattice)

Return the next-less-precise lattice that `lattice` extends (its parent layer).
"""
function widenlattice end

"""
    struct JLTypeLattice

A singleton type representing the lattice of plain Julia types, carrying no
inference extensions. This is the bottom-most layer: it has no parent lattice.
"""
struct JLTypeLattice <: AbstractLattice end

function widenlattice(::JLTypeLattice)
    # There is no layer below the type lattice to widen to.
    return error("Type lattice is the least-precise lattice available")
end

function is_valid_lattice(::JLTypeLattice, @nospecialize(elem))
    # Only plain `Type`s are elements of this lattice.
    return isa(elem, Type)
end

"""
    struct ConstsLattice

A lattice extending `JLTypeLattice` and adjoining `Const` and `PartialTypeVar`.
"""
struct ConstsLattice <: AbstractLattice end

widenlattice(::ConstsLattice) = JLTypeLattice()

function is_valid_lattice(lat::ConstsLattice, @nospecialize(elem))
    # Elements adjoined at this layer.
    isa(elem, Const) && return true
    isa(elem, PartialTypeVar) && return true
    # Otherwise defer to the parent layer.
    return is_valid_lattice(widenlattice(lat), elem)
end

"""
    struct PartialsLattice{L}

A lattice extending lattice `L` and adjoining `PartialStruct` and `PartialOpaque`.
"""
struct PartialsLattice{L <: AbstractLattice} <: AbstractLattice
    parent::L
end

widenlattice(lat::PartialsLattice) = lat.parent

function is_valid_lattice(lat::PartialsLattice, @nospecialize(elem))
    # Elements adjoined at this layer.
    isa(elem, PartialStruct) && return true
    isa(elem, PartialOpaque) && return true
    # Otherwise defer to the parent layer.
    return is_valid_lattice(widenlattice(lat), elem)
end

"""
    struct ConditionalsLattice{L}

A lattice extending lattice `L` and adjoining `Conditional`.
"""
struct ConditionalsLattice{L <: AbstractLattice} <: AbstractLattice
    parent::L
end

widenlattice(lat::ConditionalsLattice) = lat.parent

function is_valid_lattice(lat::ConditionalsLattice, @nospecialize(elem))
    # `Conditional` is the single element kind adjoined at this layer.
    isa(elem, Conditional) && return true
    return is_valid_lattice(widenlattice(lat), elem)
end

"""
    struct InterConditionalsLattice{L}

A lattice extending lattice `L` and adjoining `InterConditional` (used in
the inter-procedural result lattice; see `IPOResultLattice`).
"""
struct InterConditionalsLattice{L <: AbstractLattice} <: AbstractLattice
    parent::L
end
widenlattice(L::InterConditionalsLattice) = L.parent
is_valid_lattice(lattice::InterConditionalsLattice, @nospecialize(elem)) =
    is_valid_lattice(widenlattice(lattice), elem) || isa(elem, InterConditional)

# Union over both flavors of conditional lattice layer.
const AnyConditionalsLattice{L} = Union{ConditionalsLattice{L}, InterConditionalsLattice{L}}
# Lattice stack used for local inference: Consts -> Partials -> Conditionals.
const BaseInferenceLattice = typeof(ConditionalsLattice(PartialsLattice(ConstsLattice())))
# Lattice stack for inter-procedural results: InterConditional instead of Conditional.
const IPOResultLattice = typeof(InterConditionalsLattice(PartialsLattice(ConstsLattice())))

"""
    struct InferenceLattice{L}

The full lattice used for abstract interpretation during inference. Takes
a base lattice `L` and adjoins `LimitedAccuracy`.
"""
struct InferenceLattice{L} <: AbstractLattice
    parent::L
end
widenlattice(L::InferenceLattice) = L.parent
is_valid_lattice(lattice::InferenceLattice, @nospecialize(elem)) =
    is_valid_lattice(widenlattice(lattice), elem) || isa(elem, LimitedAccuracy)

"""
    struct OptimizerLattice

The lattice used by the optimizer. Extends `BaseInferenceLattice`
with `MaybeUndef`.
"""
struct OptimizerLattice <: AbstractLattice end

widenlattice(::OptimizerLattice) = BaseInferenceLattice.instance

function is_valid_lattice(lat::OptimizerLattice, @nospecialize(elem))
    # `MaybeUndef` is the single element kind adjoined at this layer.
    isa(elem, MaybeUndef) && return true
    return is_valid_lattice(widenlattice(lat), elem)
end

"""
    tmeet(lattice, a, b::Type)

Compute the lattice meet of lattice elements `a` and `b` over the lattice
`lattice`. If `lattice` is `JLTypeLattice`, this is equivalent to type
intersection. Note that currently `b` is restricted to being a type (interpreted
as a lattice element in the JLTypeLattice sub-lattice of `lattice`).
"""
function tmeet end

function tmeet(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type))
    ti = typeintersect(a, b)
    # An intersection that is not a valid lattice element collapses to `Bottom`.
    valid_as_lattice(ti) || return Bottom
    return ti
end

"""
    tmerge(lattice, a, b)

Compute a lattice join of elements `a` and `b` over the lattice `lattice`.
Note that the computed element need not be the least upper bound of `a` and
`b`; additional limitations are imposed on the complexity of the joined
element, ideally without losing too much precision in common cases while
remaining mostly associative and commutative.
"""
function tmerge end

"""
    ⊑(lattice, a, b)

Compute the lattice ordering (i.e. less-than-or-equal) relationship between
lattice elements `a` and `b` over the lattice `lattice`. If `lattice` is
`JLTypeLattice`, this is equivalent to subtyping.
"""
function ⊑ end

# On the plain type lattice the partial order is exactly Julia subtyping.
⊑(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type)) = a <: b

"""
    ⊏(lattice, a, b) -> Bool

The strict partial order over the type inference lattice.
This is defined as the irreflexive kernel of `⊑`.
"""
function ⊏(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b))
    # `a` must be below-or-equal `b` ...
    ⊑(lattice, a, b) || return false
    # ... and strictly so, i.e. not also above-or-equal.
    return !⊑(lattice, b, a)
end

"""
    ⋤(lattice, a, b) -> Bool

This order can be used as a slightly more efficient version of the strict
order `⊏` in contexts where `a ⊑ b` is already known to hold, since only
the reverse comparison needs to be evaluated.
"""
⋤(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) = !⊑(lattice, b, a)

"""
    is_lattice_equal(lattice, a, b) -> Bool

Check if two lattice elements are partial order equivalent.
This is basically `a ⊑ b && b ⊑ a` but (optionally) with extra performance
optimizations.
"""
function is_lattice_equal(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b))
    if a === b
        # Identical objects are trivially order-equivalent; skip both queries.
        return true
    end
    return ⊑(lattice, a, b) && ⊑(lattice, b, a)
end

"""
    has_nontrivial_const_info(lattice, t) -> Bool

Determine whether the given lattice element `t` of `lattice` has non-trivial
constant information that would not be available from the type itself.
"""
# Recurse down the lattice layering until a layer answers the query.
has_nontrivial_const_info(lattice::AbstractLattice, @nospecialize t) =
    has_nontrivial_const_info(widenlattice(lattice), t)
# Base case: a plain type carries no extra constant information.
has_nontrivial_const_info(::JLTypeLattice, @nospecialize(t)) = false

# Curried versions: partially apply the lattice, yielding a two-argument
# comparator closure suitable for passing to higher-order functions.
⊑(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊑(lattice, a, b)
⊏(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊏(lattice, a, b)
⋤(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⋤(lattice, a, b)

# Fallbacks for external packages using these methods without passing an
# explicit lattice; they default to the full inference lattice built over the
# base (resp. IPO-result) layering.
const fallback_lattice = InferenceLattice(BaseInferenceLattice.instance)
const fallback_ipo_lattice = InferenceLattice(IPOResultLattice.instance)

⊑(@nospecialize(a), @nospecialize(b)) = ⊑(fallback_lattice, a, b)
tmeet(@nospecialize(a), @nospecialize(b)) = tmeet(fallback_lattice, a, b)
tmerge(@nospecialize(a), @nospecialize(b)) = tmerge(fallback_lattice, a, b)
⊏(@nospecialize(a), @nospecialize(b)) = ⊏(fallback_lattice, a, b)
⋤(@nospecialize(a), @nospecialize(b)) = ⋤(fallback_lattice, a, b)
is_lattice_equal(@nospecialize(a), @nospecialize(b)) = is_lattice_equal(fallback_lattice, a, b)
2 changes: 2 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ include("compiler/ssair/basicblock.jl")
include("compiler/ssair/domtree.jl")
include("compiler/ssair/ir.jl")

include("compiler/abstractlattice.jl")

include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")

Expand Down
13 changes: 7 additions & 6 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

function is_argtype_match(@nospecialize(given_argtype),
function is_argtype_match(lattice::AbstractLattice,
@nospecialize(given_argtype),
@nospecialize(cache_argtype),
overridden_by_const::Bool)
if is_forwardable_argtype(given_argtype)
return is_lattice_equal(given_argtype, cache_argtype)
return is_lattice_equal(lattice, given_argtype, cache_argtype)
end
return !overridden_by_const
end
Expand Down Expand Up @@ -91,7 +92,7 @@ function matching_cache_argtypes(
for i in 1:nargs
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(given_argtype, cache_argtype, false)
if !is_argtype_match(fallback_lattice, given_argtype, cache_argtype, false)
# prefer the argtype we were given over the one computed from `linfo`
cache_argtypes[i] = given_argtype
overridden_by_const[i] = true
Expand Down Expand Up @@ -207,7 +208,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
return cache_argtypes, falses(length(cache_argtypes))
end

function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
function cache_lookup(lattice::AbstractLattice, linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
method = linfo.def::Method
nargs::Int = method.nargs
method.isva && (nargs -= 1)
Expand All @@ -218,15 +219,15 @@ function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache:
cache_argtypes = cached_result.argtypes
cache_overridden_by_const = cached_result.overridden_by_const
for i in 1:nargs
if !is_argtype_match(given_argtypes[i],
if !is_argtype_match(lattice, given_argtypes[i],
cache_argtypes[i],
cache_overridden_by_const[i])
cache_match = false
break
end
end
if method.isva && cache_match
cache_match = is_argtype_match(tuple_tfunc(given_argtypes[(nargs + 1):end]),
cache_match = is_argtype_match(lattice, tuple_tfunc(lattice, given_argtypes[(nargs + 1):end]),
cache_argtypes[end],
cache_overridden_by_const[end])
end
Expand Down
14 changes: 7 additions & 7 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_f
# inferred source in the local cache
# we still won't find a source for recursive call because the "single-level" inlining
# seems to be more trouble and complex than it's worth
inf_result = cache_lookup(mi, argtypes, get_inference_cache(interp))
inf_result = cache_lookup(optimizer_lattice(interp), mi, argtypes, get_inference_cache(interp))
inf_result === nothing && return nothing
src = inf_result.src
if isa(src, CodeInfo)
Expand Down Expand Up @@ -215,7 +215,7 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
isa(stmt, PhiNode) && return (true, true)
isa(stmt, ReturnNode) && return (false, true)
isa(stmt, GotoNode) && return (false, true)
isa(stmt, GotoIfNot) && return (false, argextype(stmt.cond, src) ⊑ Bool)
isa(stmt, GotoIfNot) && return (false, argextype(stmt.cond, src) ⊑ Bool)
isa(stmt, Slot) && return (false, false) # Slots shouldn't occur in the IR at this point, but let's be defensive here
if isa(stmt, GlobalRef)
nothrow = isdefined(stmt.mod, stmt.name)
Expand Down Expand Up @@ -248,7 +248,7 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
return (total, total)
end
rt === Bottom && return (false, false)
nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt)
nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt, OptimizerLattice())
nothrow || return (false, false)
return (contains_is(_EFFECT_FREE_BUILTINS, f), nothrow)
elseif head === :new
Expand All @@ -262,7 +262,7 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
for fld_idx in 1:(length(args) - 1)
eT = argextype(args[fld_idx + 1], src)
fT = fieldtype(typ, fld_idx)
eT ⊑ fT || return (false, false)
eT ⊑ fT || return (false, false)
end
return (true, true)
elseif head === :foreigncall
Expand All @@ -277,11 +277,11 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
typ = argextype(args[1], src)
typ, isexact = instanceof_tfunc(typ)
isexact || return (false, false)
typ ⊑ Tuple || return (false, false)
typ ⊑ Tuple || return (false, false)
rt_lb = argextype(args[2], src)
rt_ub = argextype(args[3], src)
source = argextype(args[4], src)
if !(rt_lb ⊑ Type && rt_ub ⊑ Type && source ⊑ Method)
if !(rt_lb ⊑ Type && rt_ub ⊑ Type && source ⊑ Method)
return (false, false)
end
return (true, true)
Expand Down Expand Up @@ -448,7 +448,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
else
# compute the cost (size) of inlining this code
cost_threshold = default = params.inline_cost_threshold
if result ⊑ Tuple && !isconcretetype(widenconst(result))
if ⊑(optimizer_lattice(interp), result, Tuple) && !isconcretetype(widenconst(result))
cost_threshold += params.inline_tupleret_bonus
end
# if the method is declared as `@inline`, increase the cost threshold 20x
Expand Down
28 changes: 15 additions & 13 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ function CFGInliningState(ir::IRCode)
)
end

⊑ₒ(@nospecialize(a), @nospecialize(b)) = ⊑(OptimizerLattice(), a, b)

# Tells the inliner that we're now inlining into block `block`, meaning
# all previous blocks have been processed and can be added to the new cfg
function inline_into_block!(state::CFGInliningState, block::Int)
Expand Down Expand Up @@ -381,7 +383,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
nonva_args = argexprs[1:end-1]
va_arg = argexprs[end]
tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...)
tuple_type = tuple_tfunc(Any[argextype(arg, compact) for arg in nonva_args])
tuple_type = tuple_tfunc(OptimizerLattice(), Any[argextype(arg, compact) for arg in nonva_args])
tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline))
apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg)
sparam_vals = insert_node_here!(compact,
Expand Down Expand Up @@ -476,7 +478,7 @@ function fix_va_argexprs!(compact::IncrementalCompact,
push!(tuple_call.args, arg)
push!(tuple_typs, argextype(arg, compact))
end
tuple_typ = tuple_tfunc(tuple_typs)
tuple_typ = tuple_tfunc(OptimizerLattice(), tuple_typs)
tuple_inst = NewInstruction(tuple_call, tuple_typ, line_idx)
push!(newargexprs, insert_node_here!(compact, tuple_inst))
return newargexprs
Expand Down Expand Up @@ -1080,8 +1082,8 @@ function inline_apply!(
nonempty_idx = 0
for i = (arg_start + 1):length(argtypes)
ti = argtypes[i]
ti ⊑ Tuple{} && continue
if ti ⊑ Tuple && nonempty_idx == 0
ti ⊑ Tuple{} && continue
if ti ⊑ Tuple && nonempty_idx == 0
nonempty_idx = i
continue
end
Expand Down Expand Up @@ -1123,9 +1125,9 @@ end
# TODO: this test is wrong if we start to handle Unions of function types later
is_builtin(s::Signature) =
isa(s.f, IntrinsicFunction) ||
s.ft ⊑ IntrinsicFunction ||
s.ft ⊑ IntrinsicFunction ||
isa(s.f, Builtin) ||
s.ft ⊑ Builtin
s.ft ⊑ Builtin

function inline_invoke!(
ir::IRCode, idx::Int, stmt::Expr, info::InvokeCallInfo, flag::UInt8,
Expand Down Expand Up @@ -1165,7 +1167,7 @@ function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), sta
ub, exact = instanceof_tfunc(ubt)
exact || return
# Narrow opaque closure type
newT = widenconst(tmeet(tmerge(lb, info.unspec.rt), ub))
newT = widenconst(tmeet(OptimizerLattice(), tmerge(OptimizerLattice(), lb, info.unspec.rt), ub))
if newT != ub
# N.B.: Narrowing the ub requires a backedge on the mi whose type
# information we're using, since a change in that function may
Expand Down Expand Up @@ -1222,7 +1224,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
ir.stmts[idx][:inst] = earlyres.val
return nothing
end
if (sig.f === modifyfield! || sig.ft ⊑ typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
if (sig.f === modifyfield! || sig.ft ⊑ typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
let info = ir.stmts[idx][:info]
info isa MethodResultPure && (info = info.info)
info isa ConstCallInfo && (info = info.call)
Expand All @@ -1240,7 +1242,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
end

if check_effect_free!(ir, idx, stmt, rt)
if sig.f === typeassert || sig.ft ⊑ typeof(typeassert)
if sig.f === typeassert || sig.ft ⊑ typeof(typeassert)
# typeassert is a no-op if effect free
ir.stmts[idx][:inst] = stmt.args[2]
return nothing
Expand Down Expand Up @@ -1648,7 +1650,7 @@ function early_inline_special_case(
elseif ispuretopfunction(f) || contains_is(_PURE_BUILTINS, f)
return SomeCase(quoted(val))
elseif contains_is(_EFFECT_FREE_BUILTINS, f)
if _builtin_nothrow(f, argtypes[2:end], type)
if _builtin_nothrow(f, argtypes[2:end], type, OptimizerLattice())
return SomeCase(quoted(val))
end
elseif f === Core.get_binding_type
Expand Down Expand Up @@ -1694,17 +1696,17 @@ function late_inline_special_case!(
elseif length(argtypes) == 3 && istopfunction(f, :(>:))
# special-case inliner for issupertype
# that works, even though inference generally avoids inferring the `>:` Method
if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type)
if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type, OptimizerLattice())
return SomeCase(quoted(type.val))
end
subtype_call = Expr(:call, GlobalRef(Core, :(<:)), stmt.args[3], stmt.args[2])
return SomeCase(subtype_call)
elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] ⊑ Symbol)
elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] ⊑ Symbol)
typevar_call = Expr(:call, GlobalRef(Core, :_typevar), stmt.args[2],
length(stmt.args) < 4 ? Bottom : stmt.args[3],
length(stmt.args) == 2 ? Any : stmt.args[end])
return SomeCase(typevar_call)
elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ TypeVar)
elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ TypeVar)
unionall_call = Expr(:foreigncall, QuoteNode(:jl_type_unionall), Any, svec(Any, Any),
0, QuoteNode(:ccall), stmt.args[2], stmt.args[3])
return SomeCase(unionall_call)
Expand Down
Loading