Skip to content

Commit

Permalink
fix #114, add special cased analysis pass for mutithreading code
Browse files Browse the repository at this point in the history
For the time being, we will add special analysis pass to analyze
multithreading code with `Threads.@spawn` and `Threads.@threads`
macros, in which threaded code is represented as a closure.
`NativeInterpreter` doesn't run type inference nor optimization on the
body of those closures when compiling threading code, but JET will try
to run additional analysis pass by recurring into the closures.

NOTE JET won't do anything other than doing JET analysis, e.g. won't
annotate return type of threaded code block in order to not confuse
the original `AbstractInterpreter` routine.
  • Loading branch information
aviatesk committed Mar 8, 2021
1 parent 5904731 commit 85d6124
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 15 deletions.
25 changes: 10 additions & 15 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,27 +573,23 @@ function transform_abstract_global_symbols!(interp::JETInterpreter, src::CodeInf
end

# TODO `profile_call_builtin!` ?
function profile_gf_by_type!(interp::JETInterpreter,
@nospecialize(tt::Type{<:Tuple}),
world::UInt = get_world_counter(interp),
)
mms = _methods_by_ftype(tt, InferenceParams(interp).MAX_METHODS, world)
function profile_gf_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple}))
mm = get_single_method_match(tt, InferenceParams(interp).MAX_METHODS, get_world_counter(interp))
return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams)
end

function get_single_method_match(@nospecialize(tt), lim, world)
mms = _methods_by_ftype(tt, lim, world)
@assert !isa(mms, Bool) "unable to find matching method for $(tt)"

filter!(mm::MethodMatch->mm.spec_types===tt, mms)
@assert length(mms) == 1 "unable to find single target method for $(tt)"

mm = first(mms)::MethodMatch

return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams)
return first(mms)::MethodMatch
end

function profile_method!(interp::JETInterpreter,
m::Method,
world::UInt = get_world_counter(interp),
)
return profile_method_signature!(interp, m, m.sig, method_sparams(m), world)
end
profile_method!(interp::JETInterpreter, m::Method) =
profile_method_signature!(interp, m, m.sig, method_sparams(m))

function method_sparams(m::Method)
s = TypeVar[]
Expand All @@ -609,7 +605,6 @@ function profile_method_signature!(interp::JETInterpreter,
m::Method,
@nospecialize(atype),
sparams::SimpleVector,
world::UInt = get_world_counter(interp),
)
mi = specialize_method(m, atype, sparams)

Expand Down
77 changes: 77 additions & 0 deletions src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f),
delete!(sv.pclimitations, caller)
end
end

$analyze_multithreading_pass!(interp, f, argtypes, sv)

#print("=> ", rettype, "\n")
return CallMeta(rettype, info)
end
Expand Down Expand Up @@ -588,6 +591,9 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), arg
end
end
end)

$analyze_multithreading_pass!(interp, f, argtypes, sv)

return CallMeta(rettype, info)
end

Expand Down Expand Up @@ -644,6 +650,77 @@ end

end # @static if IS_LATEST

# add special cased analysis pass for mutithreading code (xref: https://github.com/aviatesk/JET.jl/issues/114)
# we will handle multithreading with `Threads.@spawn` and `Threads.@threads` macros,
# in which threaded code is represented as a closure
# `NativeInterpreter` doesn't run type inference nor optimization on the body of those closures
# when compiling threading code, but JET will try to run additional analysis pass by recurring
# into the closures
# NOTE JET won't do anything other than doing JET analysis, e.g. won't annotate return type
# of threaded code block in order to not confuse the original `AbstractInterpreter` routine
# track https://github.com/JuliaLang/julia/pull/39773 for the changes in native abstract interpretation routine
function analyze_multithreading_pass!(interp::JETInterpreter, @nospecialize(f), argtypes::Vector{Any}, sv::InferenceState)
# special JET analysis pass for `Threads.@spawn` macro
if f === schedule
lin = get_lin(sv)
if lin.method === Symbol("macro expansion") && occursin("threadingconstructs.jl", string(lin.file))
# find task construction, and try to get its inner function
for (pc, x) in enumerate(sv.src.code)
if @isexpr(x, :(=)) && (rhs = x.args[2]; @isexpr(rhs, :call))
f = rhs.args[1]
if isa(f, GlobalRef)
v = CC.abstract_eval_global(f.mod, f.name)
if isa(v, Const) && v.val === Task
v = CC.abstract_eval_value(interp, rhs.args[2], sv.stmt_types[pc]::VarTable, sv)
# in `@spawn` macro, the closure can be a nullary lambda that
# really doesn't depend on the captured environment, and in that
# case we can retrieve a complete function object
# (otherwise, we will try to retrieve the type of the closure)
ft = (isa(v, Const) ? Core.Typeof(v.val) :
isa(v, Core.PartialStruct) ? v.typ :
isa(v, DataType) ? v :
return)::Type
return profile_additional_pass_by_type!(interp, Tuple{ft}, sv)
end
end
end
end
end
# special JET analysis pass for `Threads.@threads` macro
elseif f === Threads.threading_run
lin = get_lin(sv)
if lin.method === Symbol("macro expansion") && occursin("threadingconstructs.jl", string(lin.file))
v = argtypes[2]
# in `@threads` macro, the closure always depends on the captured environment
# (i.e. threaded range and its indices), and so we only try to retrieve the type of the closure
ft = (isa(v, Core.PartialStruct) ? v.typ :
isa(v, DataType) ? v :
return)::Type
return profile_additional_pass_by_type!(interp, Tuple{ft}, sv)
end
end
return
end

# run additional interpretation with a new interpreter,
# and then append the reports to the original interpreter
function profile_additional_pass_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple}), sv::InferenceState)
newinterp = JETInterpreter(interp)

# in order to preserve the inference termination, we keep to use the current frame
# and borrow the `AbstractInterpreter`'s cycle detection logic
# XXX the additional analysis pass by `abstract_call_method` may involve various site-effects,
# but what we're doing here is essentially equivalent to modifying the user code and inlining
# the threaded code block as a usual code block, and thus the side-effects won't (hopefully)
# confuse the abstract interpretation, which is supposed to terminate on any kind of code
# while we just run an additional analysis pass and don't record a result of the call
# for later, there still may be a risk to produce an invalid code after type inference
mm = get_single_method_match(tt, InferenceParams(newinterp).MAX_METHODS, get_world_counter(newinterp))
abstract_call_method(newinterp, mm.method, mm.spec_types, mm.sparams, false, sv)

append!(interp.reports, newinterp.reports)
end

"""
function overload_abstract_call_method_with_const_args!()
...
Expand Down
12 changes: 12 additions & 0 deletions src/abstractinterpreterinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ const _CONCRETIZED = BitVector()
const _TOPLEVELMOD = @__MODULE__
const _GLOBAL_SLOTS = Dict{Int,Symbol}()

# constructor to do additional JET analysis in the middle of parent (non-toplevel) interpretation
function JETInterpreter(interp::JETInterpreter)
return JETInterpreter(get_world_counter(interp);
current_frame = interp.current_frame,
cache = interp.cache,
analysis_params = AnalysisParams(interp),
inf_params = InferenceParams(interp),
opt_params = OptimizationParams(interp),
depth = interp.depth,
)
end

function Base.show(io::IO, interp::JETInterpreter)
rn = length(interp.reports)
en = length(interp.uncaught_exceptions)
Expand Down

0 comments on commit 85d6124

Please sign in to comment.