Skip to content

Commit

Permalink
fix #114, add special cased analysis pass for task parallelism
Browse files Browse the repository at this point in the history
In Julia's task parallelism implementation, parallel code is represented
as closure and it's wrapped in `Task` object.
`NativeInterpreter` doesn't run type inference nor optimization on the
body of those closures when compiling code that creates parallel tasks,
but JET will try to run additional analysis pass by recurring into the
closures.

NOTE JET won't do anything other than doing JET analysis, e.g. won't
annotate return type of wrapped code block in order to not confuse
the original `AbstractInterpreter` routine.
  • Loading branch information
aviatesk committed Mar 8, 2021
1 parent 5904731 commit c99b7ec
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 16 deletions.
28 changes: 12 additions & 16 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ import .CC:
compute_basic_blocks,
matching_cache_argtypes,
is_argtype_match,
tuple_tfunc
tuple_tfunc,
abstract_eval_global

import Base:
parse_input_line,
Expand Down Expand Up @@ -573,27 +574,23 @@ function transform_abstract_global_symbols!(interp::JETInterpreter, src::CodeInf
end

# TODO `profile_call_builtin!` ?
function profile_gf_by_type!(interp::JETInterpreter,
@nospecialize(tt::Type{<:Tuple}),
world::UInt = get_world_counter(interp),
)
mms = _methods_by_ftype(tt, InferenceParams(interp).MAX_METHODS, world)
function profile_gf_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple}))
mm = get_single_method_match(tt, InferenceParams(interp).MAX_METHODS, get_world_counter(interp))
return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams)
end

function get_single_method_match(@nospecialize(tt), lim, world)
mms = _methods_by_ftype(tt, lim, world)
@assert !isa(mms, Bool) "unable to find matching method for $(tt)"

filter!(mm::MethodMatch->mm.spec_types===tt, mms)
@assert length(mms) == 1 "unable to find single target method for $(tt)"

mm = first(mms)::MethodMatch

return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams)
return first(mms)::MethodMatch
end

function profile_method!(interp::JETInterpreter,
m::Method,
world::UInt = get_world_counter(interp),
)
return profile_method_signature!(interp, m, m.sig, method_sparams(m), world)
end
profile_method!(interp::JETInterpreter, m::Method) =
profile_method_signature!(interp, m, m.sig, method_sparams(m))

function method_sparams(m::Method)
s = TypeVar[]
Expand All @@ -609,7 +606,6 @@ function profile_method_signature!(interp::JETInterpreter,
m::Method,
@nospecialize(atype),
sparams::SimpleVector,
world::UInt = get_world_counter(interp),
)
mi = specialize_method(m, atype, sparams)

Expand Down
59 changes: 59 additions & 0 deletions src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f),
delete!(sv.pclimitations, caller)
end
end

$analyze_task_parallel_code!(interp, f, argtypes, sv)

#print("=> ", rettype, "\n")
return CallMeta(rettype, info)
end
Expand Down Expand Up @@ -588,6 +591,9 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), arg
end
end
end)

$analyze_task_parallel_code!(interp, f, argtypes, sv)

return CallMeta(rettype, info)
end

Expand Down Expand Up @@ -644,6 +650,59 @@ end

end # @static if IS_LATEST

# add special cased analysis pass for task parallelism (xref: https://github.com/aviatesk/JET.jl/issues/114)
# in Julia's task parallelism implementation, parallel code is represented as closure
# and it's wrapped in `Task` object
# `NativeInterpreter` doesn't run type inference nor optimization on the body of those closures
# when compiling code that creates parallel tasks, but JET will try to run additional
# analysis pass by recurring into the closures
# NOTE JET won't do anything other than doing JET analysis, e.g. won't annotate return type
# of wrapped code block in order to not confuse the original `AbstractInterpreter` routine
# track https://github.com/JuliaLang/julia/pull/39773 for the changes in native abstract interpretation routine
function analyze_task_parallel_code!(interp::JETInterpreter, @nospecialize(f), argtypes::Vector{Any}, sv::InferenceState)
# TODO ideally JET should analyze a closure wrapped in a `Task` only when encontering `schedule` call on it
# but the `Task` construction may not happen in the same frame where `schedule` is called,
# and so we may not be able to access to the closure at the point
# as a compromise, JET now invokes the additional analysis on `Task` construction,
# regardless of whether it's really `schedule`d or not
if f === Task
# if we encounter `Task(::Function)`, try to get its inner function and run analysis on it
# the closure can be a nullary lambda that really doesn't depend on
# the captured environment, and in that case we can retrieve it as
# a function object, otherwise we will try to retrieve the type of the closure
if length(argtypes) 2
v = argtypes[2]
if v Function
ft = (isa(v, Const) ? Core.Typeof(v.val) :
isa(v, Core.PartialStruct) ? v.typ :
isa(v, DataType) ? v :
return)::Type
return profile_additional_pass_by_type!(interp, Tuple{ft}, sv)
end
end
end
return
end

# run additional interpretation with a new interpreter,
# and then append the reports to the original interpreter
function profile_additional_pass_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple}), sv::InferenceState)
newinterp = JETInterpreter(interp)

# in order to preserve the inference termination, we keep to use the current frame
# and borrow the `AbstractInterpreter`'s cycle detection logic
# XXX the additional analysis pass by `abstract_call_method` may involve various site-effects,
# but what we're doing here is essentially equivalent to modifying the user code and inlining
# the threaded code block as a usual code block, and thus the side-effects won't (hopefully)
# confuse the abstract interpretation, which is supposed to terminate on any kind of code
# while we just run an additional analysis pass and don't record a result of the call
# for later, there still may be a risk to produce an invalid code after type inference
mm = get_single_method_match(tt, InferenceParams(newinterp).MAX_METHODS, get_world_counter(newinterp))
abstract_call_method(newinterp, mm.method, mm.spec_types, mm.sparams, false, sv)

append!(interp.reports, newinterp.reports)
end

"""
function overload_abstract_call_method_with_const_args!()
...
Expand Down
12 changes: 12 additions & 0 deletions src/abstractinterpreterinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ const _CONCRETIZED = BitVector()
const _TOPLEVELMOD = @__MODULE__
const _GLOBAL_SLOTS = Dict{Int,Symbol}()

# constructor to do additional JET analysis in the middle of parent (non-toplevel) interpretation
function JETInterpreter(interp::JETInterpreter)
return JETInterpreter(get_world_counter(interp);
current_frame = interp.current_frame,
cache = interp.cache,
analysis_params = AnalysisParams(interp),
inf_params = InferenceParams(interp),
opt_params = OptimizationParams(interp),
depth = interp.depth,
)
end

function Base.show(io::IO, interp::JETInterpreter)
rn = length(interp.reports)
en = length(interp.uncaught_exceptions)
Expand Down

0 comments on commit c99b7ec

Please sign in to comment.