diff --git a/src/JET.jl b/src/JET.jl index 9e923e068..f52042a9c 100644 --- a/src/JET.jl +++ b/src/JET.jl @@ -573,27 +573,23 @@ function transform_abstract_global_symbols!(interp::JETInterpreter, src::CodeInf end # TODO `profile_call_builtin!` ? -function profile_gf_by_type!(interp::JETInterpreter, - @nospecialize(tt::Type{<:Tuple}), - world::UInt = get_world_counter(interp), - ) - mms = _methods_by_ftype(tt, InferenceParams(interp).MAX_METHODS, world) +function profile_gf_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple})) + mm = get_single_method_match(tt, InferenceParams(interp).MAX_METHODS, get_world_counter(interp)) + return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams) +end + +function get_single_method_match(@nospecialize(tt), lim, world) + mms = _methods_by_ftype(tt, lim, world) @assert !isa(mms, Bool) "unable to find matching method for $(tt)" filter!(mm::MethodMatch->mm.spec_types===tt, mms) @assert length(mms) == 1 "unable to find single target method for $(tt)" - mm = first(mms)::MethodMatch - - return profile_method_signature!(interp, mm.method, mm.spec_types, mm.sparams) + return first(mms)::MethodMatch end -function profile_method!(interp::JETInterpreter, - m::Method, - world::UInt = get_world_counter(interp), - ) - return profile_method_signature!(interp, m, m.sig, method_sparams(m), world) -end +profile_method!(interp::JETInterpreter, m::Method) = + profile_method_signature!(interp, m, m.sig, method_sparams(m)) function method_sparams(m::Method) s = TypeVar[] @@ -609,7 +605,6 @@ function profile_method_signature!(interp::JETInterpreter, m::Method, @nospecialize(atype), sparams::SimpleVector, - world::UInt = get_world_counter(interp), ) mi = specialize_method(m, atype, sparams) diff --git a/src/abstractinterpretation.jl b/src/abstractinterpretation.jl index 04d4738b4..b26e9b93a 100644 --- a/src/abstractinterpretation.jl +++ b/src/abstractinterpretation.jl @@ -372,6 +372,9 @@ function 
abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), delete!(sv.pclimitations, caller) end end + + $analyze_multithreading_pass!(interp, f, argtypes, sv) + #print("=> ", rettype, "\n") return CallMeta(rettype, info) end @@ -588,6 +591,9 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), arg end end end) + + $analyze_multithreading_pass!(interp, f, argtypes, sv) + return CallMeta(rettype, info) end @@ -644,6 +650,77 @@ end end # @static if IS_LATEST +# add special cased analysis pass for multithreading code (xref: https://github.com/aviatesk/JET.jl/issues/114) +# we will handle multithreading with `Threads.@spawn` and `Threads.@threads` macros, +# in which threaded code is represented as a closure +# `NativeInterpreter` doesn't run type inference nor optimization on the body of those closures +# when compiling threading code, but JET will try to run additional analysis pass by recursing +# into the closures +# NOTE JET won't do anything other than doing JET analysis, e.g. 
won't annotate return type +# of threaded code block in order to not confuse the original `AbstractInterpreter` routine +# track https://github.com/JuliaLang/julia/pull/39773 for the changes in native abstract interpretation routine +function analyze_multithreading_pass!(interp::JETInterpreter, @nospecialize(f), argtypes::Vector{Any}, sv::InferenceState) + # special JET analysis pass for `Threads.@spawn` macro + if f === schedule + lin = get_lin(sv) + if lin.method === Symbol("macro expansion") && occursin("threadingconstructs.jl", string(lin.file)) + # find task construction, and try to get its inner function + for (pc, x) in enumerate(sv.src.code) + if @isexpr(x, :(=)) && (rhs = x.args[2]; @isexpr(rhs, :call)) + f = rhs.args[1] + if isa(f, GlobalRef) + v = CC.abstract_eval_global(f.mod, f.name) + if isa(v, Const) && v.val === Task + v = CC.abstract_eval_value(interp, rhs.args[2], sv.stmt_types[pc]::VarTable, sv) + # in `@spawn` macro, the closure can be a nullary lambda that + # really doesn't depend on the captured environment, and in that + # case we can retrieve a complete function object + # (otherwise, we will try to retrieve the type of the closure) + ft = (isa(v, Const) ? Core.Typeof(v.val) : + isa(v, Core.PartialStruct) ? v.typ : + isa(v, DataType) ? v : + return)::Type + return profile_additional_pass_by_type!(interp, Tuple{ft}, sv) + end + end + end + end + end + # special JET analysis pass for `Threads.@threads` macro + elseif f === Threads.threading_run + lin = get_lin(sv) + if lin.method === Symbol("macro expansion") && occursin("threadingconstructs.jl", string(lin.file)) + v = argtypes[2] + # in `@threads` macro, the closure always depends on the captured environment + # (i.e. threaded range and its indices), and so we only try to retrieve the type of the closure + ft = (isa(v, Core.PartialStruct) ? v.typ : + isa(v, DataType) ? 
v : + return)::Type + return profile_additional_pass_by_type!(interp, Tuple{ft}, sv) + end + end + return +end + +# run additional interpretation with a new interpreter, +# and then append the reports to the original interpreter +function profile_additional_pass_by_type!(interp::JETInterpreter, @nospecialize(tt::Type{<:Tuple}), sv::InferenceState) + newinterp = JETInterpreter(interp) + + # in order to preserve the inference termination, we keep using the current frame + # and borrow the `AbstractInterpreter`'s cycle detection logic + # XXX the additional analysis pass by `abstract_call_method` may involve various side-effects, + # but what we're doing here is essentially equivalent to modifying the user code and inlining + # the threaded code block as a usual code block, and thus the side-effects won't (hopefully) + # confuse the abstract interpretation, which is supposed to terminate on any kind of code + # while we just run an additional analysis pass and don't record a result of the call + # for later, there still may be a risk to produce an invalid code after type inference + mm = get_single_method_match(tt, InferenceParams(newinterp).MAX_METHODS, get_world_counter(newinterp)) + abstract_call_method(newinterp, mm.method, mm.spec_types, mm.sparams, false, sv) + + append!(interp.reports, newinterp.reports) +end + """ function overload_abstract_call_method_with_const_args!() ... 
diff --git a/src/abstractinterpreterinterface.jl b/src/abstractinterpreterinterface.jl index 919b7bfab..cff77b775 100644 --- a/src/abstractinterpreterinterface.jl +++ b/src/abstractinterpreterinterface.jl @@ -103,6 +103,18 @@ const _CONCRETIZED = BitVector() const _TOPLEVELMOD = @__MODULE__ const _GLOBAL_SLOTS = Dict{Int,Symbol}() +# constructor to do additional JET analysis in the middle of parent (non-toplevel) interpretation +function JETInterpreter(interp::JETInterpreter) + return JETInterpreter(get_world_counter(interp); + current_frame = interp.current_frame, + cache = interp.cache, + analysis_params = AnalysisParams(interp), + inf_params = InferenceParams(interp), + opt_params = OptimizationParams(interp), + depth = interp.depth, + ) +end + function Base.show(io::IO, interp::JETInterpreter) rn = length(interp.reports) en = length(interp.uncaught_exceptions)