Skip to content

Commit

Permalink
fix #142, global variable woe, again
Browse files Browse the repository at this point in the history
1. trying to propagate types of a global variable when in a toplevel 
frame
2. don't propagate type of abstract global variables when in a local 
frame

With 2., JET analysis is now more consistent with how Julia's native type
inference works, and we can eliminate the hacky logic to invalidate
a dummy cache associated with the previously analyzed abstract global's
type.
  • Loading branch information
aviatesk committed Mar 28, 2021
1 parent 326f659 commit ae53bda
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 164 deletions.
5 changes: 1 addition & 4 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,6 @@ function is_constant_propagated(frame::InferenceState)
return !frame.cached && CC.any(frame.result.overridden_by_const)
end

# Whether `linfo.def` is identical to `__virtual_toplevel__`.
@inline function istoplevel(linfo::MethodInstance)
    return linfo.def === __virtual_toplevel__
end
# Frame variant: forwards to the `MethodInstance` method using `sv.linfo`.
@inline function istoplevel(sv::InferenceState)
    return istoplevel(sv.linfo)
end

prewalk_inf_frame(@nospecialize(f), ::Nothing) = return
function prewalk_inf_frame(@nospecialize(f), frame::InferenceState)
ret = f(frame)
Expand Down Expand Up @@ -621,7 +618,7 @@ function analyze_toplevel!(interp::JETInterpreter, src::CodeInfo)
mi.specTypes = Tuple{}

transform_abstract_global_symbols!(interp, src)
mi.def = __virtual_toplevel__ # set to the dummy module
mi.def = interp.toplevelmod

result = InferenceResult(mi);
# toplevel frame doesn't need to be cached (and so it won't be optimized)
Expand Down
180 changes: 79 additions & 101 deletions src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,29 @@
mutable struct AbstractGlobal
# analyzed type
t::Any
# `id` of `JETInterpreter` that defined this
id::Symbol
# `Symbol` of a dummy generic function that generates dummy backedge (i.e. `li`)
edge_sym::Symbol
# dummy backedge, which will be invalidated on update of `t`
li::MethodInstance
# whether this abstract global variable is declared as constant or not
iscd::Bool
end
Wraps a global variable whose type is analyzed by abstract interpretation.
`AbstractGlobal` object will be actually evaluated into the context module, and a later
analysis may refer to its type or alter it on another assignment.
On the refinement of the abstract global variable, the dummy backedge associated with it
will be invalidated, and inference depending on that will be re-run on the next analysis.
!!! note
The type of the wrapped global variable will be propagated only when in a toplevel frame,
and thus we don't care about the analysis cache invalidation on a refinement of the
wrapped global variable, since JET doesn't cache the toplevel frame.
"""
mutable struct AbstractGlobal
# analyzed type
t::Any
# `id` of the `JETInterpreter` that last assigned this global variable
id::Symbol
# the name of a dummy generic function that generates dummy backedge `li`
edge_sym::Symbol
# dummy backedge associated to this global variable, which will be invalidated on update of `t`
li::MethodInstance
# whether this abstract global variable is declared as constant or not
iscd::Bool

# inner constructor; `t` is taken with `@nospecialize`
function AbstractGlobal(@nospecialize(t),
id::Symbol,
edge_sym::Symbol,
li::MethodInstance,
iscd::Bool,
)
return new(t, id, edge_sym, li, iscd)
# NOTE(review): the statement below is unreachable because of the `return` above;
# it looks like a leftover of a two-field (`t`, `iscd`) layout — confirm which one is intended
return new(t,
iscd)
end
end

Expand Down Expand Up @@ -93,7 +81,7 @@ An overload for `abstract_call_gf_by_type(interp::JETInterpreter, ...)`, which k
[`virtual_process!`](@ref).
"""
function CC.bail_out_toplevel_call(interp::JETInterpreter, @nospecialize(sig), sv)
    # Bail out only when all of the following hold:
    # - the frame's `linfo.def` is a `Module` (i.e. not a `Method`),
    # - `sig` is not a dispatch tuple,
    # - the frame is not the toplevel frame of `interp`, i.e. `sv.linfo.def` is not
    #   `interp.toplevelmod` (this is what `istoplevel(interp, sv)` checks).
    # The interpreter-aware `istoplevel` must be used here: the frame's `def` is set to
    # `interp.toplevelmod`, so a check that does not consult `interp` would never
    # recognize the interpreter's own toplevel frame.
    return isa(sv.linfo.def, Module) && !isdispatchtuple(sig) && !istoplevel(interp, sv)
end

@doc """
Expand Down Expand Up @@ -324,7 +312,8 @@ end
end # @static if isdefined(CC, :abstract_invoke)

function CC.abstract_eval_special_value(interp::JETInterpreter, @nospecialize(e), vtypes::VarTable, sv::InferenceState)
if istoplevel(sv)
toplevel = istoplevel(interp, sv)
if toplevel
if isa(e, Slot) && is_global_slot(interp, e)
if get_slottype(sv, e) === Bottom
# if this abstract global variable is not initialized, form the global
Expand All @@ -341,33 +330,30 @@ function CC.abstract_eval_special_value(interp::JETInterpreter, @nospecialize(e)

ret = @invoke abstract_eval_special_value(interp::AbstractInterpreter, e, vtypes::VarTable, sv::InferenceState)

if isa(ret, Const)
# unwrap abstract global variable to actual type
val = ret.val
if isa(val, AbstractGlobal)
# add dummy backedge, which will be invalidated on update of this virtual global variable
add_backedge!(val.li, sv)

ret = val.t
end
elseif isa(e, GlobalRef)
if isa(e, GlobalRef)
mod, name = e.mod, e.name
if isdefined(mod, name)
# we don't track types of global variables except when we're in toplevel frame,
# and here we just annotate this as `Any`; NOTE: this is because:
# - we can't track side-effects of assignments of global variables that happen in
# (possibly deeply nested) callees, and it might be possible if we just ignore
# assignments happens in callees that aren't reached by type inference by
# the widening heuristics
# - consistency with Julia's native type inference
# - it's hard to track side effects for cached frames

# TODO: add report pass here (for performance linting)

# special case and propagate `Main` module as constant
# XXX this was somewhat critical for accuracy and performance, but I'm not sure this still holds
if name === :Main
# special case and propagate `Main` module as constant
# XXX this was somewhat critical for accuracy and performance, but I'm not sure this still holds
ret = Const(Main)
elseif toplevel
# here we will eagerly propagate the type of this global variable
# of course the traced type might be different from its type in actual execution
# e.g. we don't track a global variable assignment wrapped in a function,
# but it's highly possible this is a toplevel callsite and we need to take a
# risk here, otherwise we can't enter the analysis !
val = getfield(mod, name)
ret = isa(val, AbstractGlobal) ? val.t : Const(val)
else
# we don't track types of global variables except when we're in toplevel frame,
# and here `ret` should be just annotated as `Any`
# NOTE: this is because:
# - it's hard (or even impossible) to correctly track side-effects of global
# variable assignments that happen in (possibly deeply nested) frames
# - to be consistent with Julia's native type inference
# - it's hard to track side effects for cached frames
# TODO: add report pass here (for performance linting)
end
else
# report access to undefined global variable
Expand Down Expand Up @@ -422,7 +408,7 @@ function CC.abstract_eval_value(interp::JETInterpreter, @nospecialize(e), vtypes
end

function CC.abstract_eval_statement(interp::JETInterpreter, @nospecialize(e), vtypes::VarTable, sv::InferenceState)
if istoplevel(sv)
if istoplevel(interp, sv)
if interp.concretized[get_currpc(sv)]
return Any # bail out if it has been interpreted by `ConcreteInterpreter`
end
Expand All @@ -434,7 +420,7 @@ end
function CC.finish(me::InferenceState, interp::JETInterpreter)
@invoke finish(me::InferenceState, interp::AbstractInterpreter)

if istoplevel(me)
if istoplevel(interp, me)
# find assignments of abstract global variables, and assign types to them,
# so that later analysis can refer to them

Expand Down Expand Up @@ -518,70 +504,74 @@ function is_nondeterministic(pc, bbs)
end

function set_abstract_global!(interp, mod, name, @nospecialize(t), isnd, sv)
local update::Bool = false
id = get_id(interp)

prev_agv = nothing
prev_t = nothing
iscd = is_constant_declared(name, sv)

t′, id′, (edge_sym, li) = if isdefined(mod, name)
# check if this global variable is already assigned previously
if isdefined(mod, name)
val = getfield(mod, name)
if isa(val, AbstractGlobal)
t′ = val.t
if val.iscd && widenconst(t′) !== widenconst(t)
report!(interp, InvalidConstantRedefinition(interp, sv, mod, name, widenconst(t′), widenconst(t)))
prev_t = val.t
if val.iscd && widenconst(prev_t) !== widenconst(t)
report!(interp, InvalidConstantRedefinition(interp, sv, mod, name, widenconst(prev_t), widenconst(t)))
return
end

# update previously-defined abstract global variable
update = true
t′, val.id, (val.edge_sym, val.li)
prev_agv = val
else
prev_t = Core.Typeof(val)
if isconst(mod, name)
t′ = typeof(val)
if t′ !== widenconst(t)
report!(interp, InvalidConstantRedefinition(interp, sv, mod, name, t′, widenconst(t)))
invalid = prev_t !== widenconst(t)
if invalid || !isa(t, Const)
@warn """
JET.jl can't update the definition of this constant declared global variable: $mod.$name
This may fail, cause incorrect analysis, or produce unexpected errors.
"""
invalid && report!(interp, InvalidConstantRedefinition(interp, sv, mod, name, prev_t, widenconst(t)))
return
end
# otherwise, we can just redefine this constant, and Julia will warn it
ex = iscd ? :(const $name = $(QuoteNode(t.val))) : :($name = $(QuoteNode(t.val)))
return Core.eval(mod, ex)
end

# this pass hopefully won't happen within the current design
@warn "JET.jl can't trace updates of global variable that already have values" mod name val
return
end
else
# define new abstract global variable
Bottom, id, gen_dummy_backedge(mod)
end

# if this is constant declared and its value is known to be constant, let's concretize
# it for good reasons; we will be able to use it in concrete interpretation and so
# this allows us to define structs with global type aliases, etc.
# XXX maybe check for constant declaration is not necessary
if isa(t, Const) && iscd
return Core.eval(mod, :(const $(name) = $(QuoteNode(t.val))))
isnew = isnothing(prev_t)

# if this global variable is known to be constant statically, let's concretize it for
# good reasons; we will be able to use it in concrete interpretation and so this allows
# us to define structs with type aliases, etc.
if isa(t, Const)
if iscd
if !isnew && prev_t !== widenconst(t)
@warn """
JET.jl can't update the definition of this constant declared global variable: $mod.$name
This may fail, cause incorrect analysis, or produce unexpected errors.
"""
report!(interp, InvalidConstantRedefinition(interp, sv, mod, name, t′, widenconst(t)))
return
end
return Core.eval(mod, :(const $name = $(QuoteNode(t.val))))
else
return Core.eval(mod, :($name = $(QuoteNode(t.val))))
end
end

# if this assignment happens in a non-deterministic way, we need to perform type merge
isnd && (t = tmerge(t′, t))

if id !== id′
# invalidate the dummy backedge that is bound to this abstract global variable,
# so that depending `MethodInstance` will run fresh type inference on the next hit
li = force_invalidate!(mod, edge_sym)
# if this assignment happens non-deterministically, we need to perform type merge
if !isnew
t = tmerge(prev_t, t)
end

ex = if update
quote let name = $name::$AbstractGlobal
name.t = $(t)
name.id = $(QuoteNode(id))
name.edge_sym = $(QuoteNode(edge_sym))
name.li = $(li)
# okay, we will define new abstract global variable from here on
if isa(prev_agv, AbstractGlobal)
return Core.eval(mod, :(let name = $name::$AbstractGlobal
name.t = $t
name
end end
end))
else
:(const $name = $(AbstractGlobal(t, id, edge_sym, li, iscd)))
return Core.eval(mod, :($name = $(AbstractGlobal(t, iscd))))
end
return Core.eval(mod, ex)
end

function is_constant_declared(name, sv)
Expand All @@ -596,15 +586,3 @@ function is_constant_declared(name, sv)
return false
end
end

# Generate a fresh symbol for a dummy edge and the `MethodInstance` that
# `force_invalidate!` produces for it in `mod`; returns them as `(symbol, instance)`.
function gen_dummy_backedge(mod)
    sym = gensym(:dummy_edge_sym)
    # just generate dummy `MethodInstance` to be invalidated
    dummy_li = force_invalidate!(mod, sym)
    return sym, dummy_li
end

# (Re-)evaluate a zero-argument method named `edge_sym` in `mod` and return the
# `MethodInstance` specialized for calling it with no arguments.
# TODO: find a more fine-grained way to do this ? re-evaluating an entire function seems to be over-kill for this
function force_invalidate!(mod, edge_sym)
    dummy = Core.eval(mod, :($(edge_sym)() = return))::Function
    meth = first(methods(dummy))
    sig = Tuple{typeof(dummy)}
    return specialize_method(meth, sig, svec())::MethodInstance
end
32 changes: 18 additions & 14 deletions src/abstractinterpreterinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,15 @@ mutable struct JETInterpreter <: AbstractInterpreter

## virtual toplevel execution ##

# for sequential assignment of abstract global variables
id::Symbol

# will be used in toplevel analysis (skip inference on actually interpreted statements)
concretized::BitVector

# virtual toplevel module
toplevelmod::Module

# toplevel modules concretized by JET (only active within sequential toplevel analysis)
toplevelmods::Set{Module}

# slots to represent toplevel global variables
global_slots::Dict{Int,Symbol}

Expand All @@ -298,10 +298,10 @@ end
analysis_params = nothing,
inf_params = nothing,
opt_params = nothing,
id = gensym(:JETInterpreterID),
concretized = _CONCRETIZED,
toplevelmod = _TOPLEVELMOD,
global_slots = _GLOBAL_SLOTS,
concretized = BitVector(),
toplevelmod = __toplevelmod__,
toplevelmods = Set{Module}(),
global_slots = Dict{Int,Symbol}(),
logger = nothing,
depth = 0,
jetconfigs...)
Expand All @@ -317,18 +317,17 @@ end
cache,
analysis_params,
false,
id,
concretized,
toplevelmod,
toplevelmods,
global_slots,
logger,
depth,
)
end
# dummies to interpret non-toplevel frames
const _CONCRETIZED = BitVector()
const _TOPLEVELMOD = @__MODULE__
const _GLOBAL_SLOTS = Dict{Int,Symbol}()

# dummies for non-toplevel analysis
module __toplevelmod__ end

# constructor for sequential toplevel JET analysis
function JETInterpreter(interp::JETInterpreter, concretized, toplevelmod)
Expand All @@ -340,6 +339,7 @@ function JETInterpreter(interp::JETInterpreter, concretized, toplevelmod)
opt_params = OptimizationParams(interp),
concretized = concretized, # or construct partial `CodeInfo` from remaining abstract statements ?
toplevelmod = toplevelmod,
toplevelmods = push!(interp.toplevelmods, toplevelmod),
logger = JETLogger(interp),
)
end
Expand Down Expand Up @@ -410,8 +410,6 @@ JETAnalysisParams(interp::JETInterpreter) = interp.analysis_params

JETLogger(interp::JETInterpreter) = interp.logger

get_id(interp::JETInterpreter) = interp.id

# TODO do report filtering or something configured by `JETAnalysisParams(interp)`
function report!(interp::JETInterpreter, report::InferenceErrorReport)
push!(interp.reports, report)
Expand All @@ -421,6 +419,12 @@ function stash_uncaught_exception!(interp::JETInterpreter, report::UncaughtExcep
push!(interp.uncaught_exceptions, report)
end

# check if we're in a toplevel module:
# a frame is toplevel when its `linfo` is, ...
@inline function istoplevel(interp::JETInterpreter, sv::InferenceState)
    return istoplevel(interp, sv.linfo)
end
# ... and a `MethodInstance` is toplevel when its `def` is identical to `interp.toplevelmod`
@inline function istoplevel(interp::JETInterpreter, linfo::MethodInstance)
    return interp.toplevelmod === linfo.def
end

# Whether `mod` is a member of `interp.toplevelmods`.
@inline function istoplevelmod(interp::JETInterpreter, mod::Module)
    return in(mod, interp.toplevelmods)
end

# Whether the slot index `slot` is a key of `interp.global_slots`.
function is_global_slot(interp::JETInterpreter, slot::Int)
    return haskey(interp.global_slots, slot)
end
# `Slot` objects are looked up through their `slot_id`.
function is_global_slot(interp::JETInterpreter, slot::Slot)
    return is_global_slot(interp, slot_id(slot))
end
# Whether `sym` appears among the values of `interp.global_slots`.
function is_global_slot(interp::JETInterpreter, sym::Symbol)
    return in(sym, values(interp.global_slots))
end
Expand Down
2 changes: 1 addition & 1 deletion src/legacy/abstractinterpretation
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), arg
method = match.method
sig = match.spec_types
#=== abstract_call_gf_by_type patch point 2 start ===#
if istoplevel && !isdispatchtuple(sig) && !$istoplevel(sv) # keep going for "our" toplevel frame
if istoplevel && !isdispatchtuple(sig) && !$istoplevel(interp, sv) # keep going for "our" toplevel frame
#=== abstract_call_gf_by_type patch point 2 end ===#
# only infer concrete call sites in top-level expressions
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
Expand Down
Loading

0 comments on commit ae53bda

Please sign in to comment.