From c99b7ec4d1fc0dce1e77ae7a88a8711889aff848 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 9 Mar 2021 00:18:21 +0900 Subject: [PATCH] fix #114, add special cased analysis pass for task parallelism In Julia's task parallelism implementation, parallel code is represented as closure and it's wrapped in `Task` object. `NativeInterpreter` doesn't run type inference nor optimization on the body of those closures when compiling code that creates parallel tasks, but JET will try to run additional analysis pass by recurring into the closures. NOTE JET won't do anything other than doing JET analysis, e.g. won't annotate return type of wrapped code block in order to not confuse the original `AbstractInterpreter` routine. --- src/JET.jl | 28 ++++++-------- src/abstractinterpretation.jl | 59 +++++++++++++++++++++++++++++ src/abstractinterpreterinterface.jl | 12 ++++++ 3 files changed, 83 insertions(+), 16 deletions(-) diff --git a/src/JET.jl b/src/JET.jl index 9e923e068..46b183521 100644 --- a/src/JET.jl +++ b/src/JET.jl @@ -108,7 +108,8 @@ import .CC: compute_basic_blocks, matching_cache_argtypes, is_argtype_match, - tuple_tfunc + tuple_tfunc, + abstract_eval_global import Base: parse_input_line, @@ -573,27 +574,23 @@ function transform_abstract_global_symbols!(interp::JETInterpreter, src::CodeInf end # TODO `profile_call_builtin!` ? -function profile_gf_by_type!(interp::JETInterpreter, - @nospecialize(tt::Type{<:Tuple}), - world::UInt = get_world_counter(interp), - ) - mms = _methods_by_ftype(tt, InferenceParams(interp).MAX_METHODS, world) +function profile_gf_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple})) + mm = get_single_method_match(tt, InferenceParams(interp).MAX_METHODS, get_world_counter(interp)) + return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams) +end + +function get_single_method_match(@nospecialize(tt), lim, world) + mms = _methods_by_ftype(tt, lim, world) @assert !isa(mms, Bool) "unable to find matching method for $(tt)" filter!(mm::MethodMatch->mm.spec_types===tt, mms) @assert length(mms) == 1 "unable to find single target method for $(tt)" - mm = first(mms)::MethodMatch - - return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams) + return first(mms)::MethodMatch end -function profile_method!(interp::JETInterpreter, - m::Method, - world::UInt = get_world_counter(interp), - ) - return profile_method_signature!(interp, m, m.sig, method_sparams(m), world) -end +profile_method!(interp::JETInterpreter, m::Method) = + profile_method_signature!(interp, m, m.sig, method_sparams(m)) function method_sparams(m::Method) s = TypeVar[] @@ -609,7 +606,6 @@ function profile_method_signature!(interp::JETInterpreter, m::Method, @nospecialize(atype), sparams::SimpleVector, - world::UInt = get_world_counter(interp), ) mi = specialize_method(m, atype, sparams) diff --git a/src/abstractinterpretation.jl b/src/abstractinterpretation.jl index 04d4738b4..7e149e4c4 100644 --- a/src/abstractinterpretation.jl +++ b/src/abstractinterpretation.jl @@ -372,6 +372,9 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), delete!(sv.pclimitations, caller) end end + + $analyze_task_parallel_code!(interp, f, argtypes, sv) + #print("=> ", rettype, "\n") return CallMeta(rettype, info) end @@ -588,6 +591,9 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), arg end end end) + + $analyze_task_parallel_code!(interp, f, argtypes, sv) + return CallMeta(rettype, info) end @@ -644,6 +650,59 @@ end end # @static if IS_LATEST +# add special cased analysis pass for task parallelism (xref: https://github.com/aviatesk/JET.jl/issues/114) +# in Julia's task parallelism implementation, parallel code is represented as closure +# and it's wrapped in `Task` object +# `NativeInterpreter` doesn't run type inference nor optimization on the body of those closures +# when compiling code that creates parallel tasks, but JET will try to run additional +# analysis pass by recurring into the closures +# NOTE JET won't do anything other than doing JET analysis, e.g. won't annotate return type +# of wrapped code block in order to not confuse the original `AbstractInterpreter` routine +# track https://github.com/JuliaLang/julia/pull/39773 for the changes in native abstract interpretation routine +function analyze_task_parallel_code!(interp::JETInterpreter, @nospecialize(f), argtypes::Vector{Any}, sv::InferenceState) + # TODO ideally JET should analyze a closure wrapped in a `Task` only when encontering `schedule` call on it + # but the `Task` construction may not happen in the same frame where `schedule` is called, + # and so we may not be able to access to the closure at the point + # as a compromise, JET now invokes the additional analysis on `Task` construction, + # regardless of whether it's really `schedule`d or not + if f === Task + # if we encounter `Task(::Function)`, try to get its inner function and run analysis on it + # the closure can be a nullary lambda that really doesn't depend on + # the captured environment, and in that case we can retrieve it as + # a function object, otherwise we will try to retrieve the type of the closure + if length(argtypes) ≥ 2 + v = argtypes[2] + if v ⊑ Function + ft = (isa(v, Const) ? Core.Typeof(v.val) : + isa(v, Core.PartialStruct) ? v.typ : + isa(v, DataType) ? v : + return)::Type + return profile_additional_pass_by_type!(interp, Tuple{ft}, sv) + end + end + end + return +end + +# run additional interpretation with a new interpreter, +# and then append the reports to the original interpreter +function profile_additional_pass_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple}), sv::InferenceState) + newinterp = JETInterpreter(interp) + + # in order to preserve the inference termination, we keep to use the current frame + # and borrow the `AbstractInterpreter`'s cycle detection logic + # XXX the additional analysis pass by `abstract_call_method` may involve various site-effects, + # but what we're doing here is essentially equivalent to modifying the user code and inlining + # the threaded code block as a usual code block, and thus the side-effects won't (hopefully) + # confuse the abstract interpretation, which is supposed to terminate on any kind of code + # while we just run an additional analysis pass and don't record a result of the call + # for later, there still may be a risk to produce an invalid code after type inference + mm = get_single_method_match(tt, InferenceParams(newinterp).MAX_METHODS, get_world_counter(newinterp)) + abstract_call_method(newinterp, mm.method, mm.spec_types, mm.sparams, false, sv) + + append!(interp.reports, newinterp.reports) +end + """ function overload_abstract_call_method_with_const_args!() ... diff --git a/src/abstractinterpreterinterface.jl b/src/abstractinterpreterinterface.jl index 919b7bfab..cff77b775 100644 --- a/src/abstractinterpreterinterface.jl +++ b/src/abstractinterpreterinterface.jl @@ -103,6 +103,18 @@ const _CONCRETIZED = BitVector() const _TOPLEVELMOD = @__MODULE__ const _GLOBAL_SLOTS = Dict{Int,Symbol}() +# constructor to do additional JET analysis in the middle of parent (non-toplevel) interpretation +function JETInterpreter(interp::JETInterpreter) + return JETInterpreter(get_world_counter(interp); + current_frame = interp.current_frame, + cache = interp.cache, + analysis_params = AnalysisParams(interp), + inf_params = InferenceParams(interp), + opt_params = OptimizationParams(interp), + depth = interp.depth, + ) +end + function Base.show(io::IO, interp::JETInterpreter) rn = length(interp.reports) en = length(interp.uncaught_exceptions)