Skip to content

Commit

Permalink
simplify caching logic
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Dec 15, 2022
1 parent 18344e2 commit 2c49256
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 148 deletions.
10 changes: 7 additions & 3 deletions src/abstractinterpret/abstractanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -531,23 +531,27 @@ get_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[res
get_cached_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[result]::JETCachedResult).reports
get_any_reports(analyzer::AbstractAnalyzer, result::InferenceResult) = (analyzer[result]::AnyJETResult).reports

# HACK to avoid runtime dispatch
@inline push_report!(reports::Vector{InferenceErrorReport}, @nospecialize(report::InferenceErrorReport)) =
@invoke push!(reports::Vector, report::InferenceErrorReport)

"""
add_new_report!(analyzer::AbstractAnalyzer, result::InferenceResult, report::InferenceErrorReport)
Adds new [`report::InferenceErrorReport`](@ref InferenceErrorReport) associated with `result::InferenceResult`.
"""
function add_new_report!(analyzer::AbstractAnalyzer, result::InferenceResult, @nospecialize(report::InferenceErrorReport))
push!(get_reports(analyzer, result), report)
push_report!(get_reports(analyzer, result), report)
return report
end

function add_cached_report!(analyzer::AbstractAnalyzer, caller::InferenceResult, @nospecialize(cached::InferenceErrorReport))
cached = copy_report′(cached)
push!(get_reports(analyzer, caller), cached)
push_report!(get_reports(analyzer, caller), cached)
return cached
end

add_caller_cache!(analyzer::AbstractAnalyzer, @nospecialize(report::InferenceErrorReport)) = push!(get_caller_cache(analyzer), report)
add_caller_cache!(analyzer::AbstractAnalyzer, @nospecialize(report::InferenceErrorReport)) = push_report!(get_caller_cache(analyzer), report)
add_caller_cache!(analyzer::AbstractAnalyzer, reports::Vector{InferenceErrorReport}) = append!(get_caller_cache(analyzer), reports)

# AbstractInterpreter
Expand Down
112 changes: 18 additions & 94 deletions src/abstractinterpret/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ end
# cache
# =====

cache_report!(cache, @nospecialize(report::InferenceErrorReport)) =
push!(cache, copy_report′(report)::InferenceErrorReport)
cache_report!(cache::Vector{InferenceErrorReport}, @nospecialize(report::InferenceErrorReport)) =
push_report!(cache, copy_report′(report)::InferenceErrorReport)

struct AbstractAnalyzerView{Analyzer<:AbstractAnalyzer}
analyzer::Analyzer
Expand Down Expand Up @@ -340,6 +340,7 @@ end # @static if hasmethod(CC.transform_result_for_cache, (...))

function CC.transform_result_for_cache(analyzer::AbstractAnalyzer,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
istoplevel(linfo) && return nothing
cache = InferenceErrorReport[]
for report in get_reports(analyzer, result)
@static if JET_DEV_MODE
Expand Down Expand Up @@ -543,104 +544,27 @@ function filter_lineages!(analyzer::AbstractAnalyzer, caller::InferenceResult, c
filter!(!islineage(caller.linfo, current), get_reports(analyzer, caller))
end

# in this overload we can work on `frame.src::CodeInfo` (and also `frame::InferenceState`)
# where type inference (and also optimization if applied) already ran on
function CC._typeinf(analyzer::AbstractAnalyzer, frame::InferenceState)
CC.typeinf_nocycle(analyzer, frame) || return false # frame is now part of a higher cycle
# with no active ip's, frame is done
frames = frame.callers_in_cycle
isempty(frames) && push!(frames, frame)
valid_worlds = WorldRange()
for caller in frames
@assert !(caller.dont_work_on_me)
caller.dont_work_on_me = true
# might might not fully intersect these earlier, so do that now
valid_worlds = CC.intersect(caller.valid_worlds, valid_worlds)
end
for caller in frames
caller.valid_worlds = valid_worlds
CC.finish(caller, analyzer)
# finalize and record the linfo result
caller.inferred = true
end
# NOTE we don't discard `InferenceState`s here so that some analyzers can use them in `finish!`
# # collect results for the new expanded frame
# results = Tuple{InferenceResult, Vector{Any}, Bool}[
# ( frames[i].result,
# frames[i].stmt_edges[1]::Vector{Any},
# frames[i].cached )
# for i in 1:length(frames) ]
# empty!(frames)
for frame in frames
caller = frame.result
opt = caller.src
if (@static VERSION v"1.9.0-DEV.1636" ?
(opt isa OptimizationState{typeof(analyzer)}) :
(opt isa OptimizationState))
CC.optimize(analyzer, opt, OptimizationParams(analyzer), caller)
# # COMBAK we may want to enable inlining ?
# if opt.const_api
# # XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
# # we're doing it is so that code_llvm can return the code
# # for the `return ...::Const` (which never runs anyway). We should do this
# # as a post processing step instead.
# CC.ir_to_codeinf!(opt)
# if result_type isa Const
# caller.src = result_type
# else
# @assert CC.isconstType(result_type)
# caller.src = Const(result_type.parameters[1])
# end
# end
caller.valid_worlds = CC.getindex((opt.inlining.et::CC.EdgeTracker).valid_worlds)
end
end
function CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult)
reports = get_reports(analyzer, caller)

for frame in frames
caller = frame.result
edges = frame.stmt_edges[1]::Vector{Any}
cached = frame.cached
valid_worlds = caller.valid_worlds
if CC.last(valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
CC.store_backedges(caller, edges)
end
CC.finish!(analyzer, frame)

reports = get_reports(analyzer, caller)
# XXX this is a dirty fix for performance problem, we need more "proper" fix
# https://github.com/aviatesk/JET.jl/issues/75
unique!(aggregation_policy(analyzer), reports)

# XXX this is a dirty fix for performance problem, we need more "proper" fix
# https://github.com/aviatesk/JET.jl/issues/75
unique!(aggregation_policy(analyzer), reports)
if get_entry(analyzer) !== caller.linfo
# inter-procedural handling: get back to the caller what we got from these results
add_caller_cache!(analyzer, reports)

# global cache management
if cached && !istoplevel(frame)
CC.cache_result!(analyzer, caller)
end

if frame.parent !== nothing
# inter-procedural handling: get back to the caller what we got from these results
add_caller_cache!(analyzer, reports)

# local cache management
# TODO there are duplicated work here and `transform_result_for_cache`
cache = InferenceErrorReport[]
for report in reports
cache_report!(cache, report)
end
set_cached_result!(analyzer, caller, cache)
# local cache management
# TODO there are duplicated work here and `transform_result_for_cache`
cache = InferenceErrorReport[]
for report in reports
cache_report!(cache, report)
end
set_cached_result!(analyzer, caller, cache)
end

return true
end

# by default, this overload just is forwarded to the AbstractInterpreter's implementation
# but the only reason we have this overload is that some analyzers (like `JETAnalyzer`)
# can further overload this to generate `InferenceErrorReport` with an access to `frame`
function CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)
return CC.finish!(analyzer, frame.result)
return @invoke CC.finish!(analyzer::AbstractInterpreter, caller::InferenceResult)
end

# top-level bridge
Expand Down
84 changes: 44 additions & 40 deletions src/analyzers/jetanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,17 @@ function CC.InferenceState(result::InferenceResult, cache::Symbol, analyzer::JET
return frame
end

function CC.finish!(analyzer::JETAnalyzer, frame::InferenceState)
src = @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)

if isnothing(src)
# caught in cycle, similar error should have been reported where the source is available
return src
else
code = (src::CodeInfo).code
function CC.finish!(analyzer::JETAnalyzer, caller::InferenceResult)
src = @invoke CC.finish!(analyzer::AbstractInterpreter, caller::InferenceResult)
if src isa CodeInfo
# report pass for uncaught `throw` calls
ReportPass(analyzer)(UncaughtExceptionReport, analyzer, frame, code)
return src
ReportPass(analyzer)(UncaughtExceptionReport, analyzer, caller, src)
else
# very much optimized (nothing to report), or very much unoptimized:
# in a case of the latter, similar error should have been reported
# where the source is available
end
return @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult)
end

let # overload `abstract_call_gf_by_type`
Expand Down Expand Up @@ -484,56 +483,60 @@ end
Represents general `throw` calls traced during inference.
This is reported only when it's not caught by control flow.
"""
@jetreport struct UncaughtExceptionReport <: InferenceErrorReport
throw_calls::Vector{Tuple{Int,Expr}} # (pc, call)
end
function UncaughtExceptionReport(sv::InferenceState, throw_calls::Vector{Tuple{Int,Expr}})
vf = get_virtual_frame(sv.linfo)
sig = Any[]
ncalls = length(throw_calls)
for (i, (pc, call)) in enumerate(throw_calls)
call_sig = get_sig_nowrap((sv, pc), call)
append!(sig, call_sig)
i ncalls && push!(sig, ", ")
end
return UncaughtExceptionReport([vf], Signature(sig), throw_calls)
end
function print_report_message(io::IO, (; throw_calls)::UncaughtExceptionReport)
msg = length(throw_calls) == 1 ? "may throw" : "may throw either of"
print(io, msg)
end
@jetreport struct UncaughtExceptionReport <: InferenceErrorReport end
print_report_message(io::IO, ::UncaughtExceptionReport) = print(io, "may throw")
print_signature(::UncaughtExceptionReport) = false

# @jetreport struct UncaughtExceptionReport <: InferenceErrorReport
# throw_calls::Vector{Tuple{Int,Expr}} # (pc, call)
# end
# function UncaughtExceptionReport(caller::InferenceResult, throw_calls::Vector{Tuple{Int,Expr}})
# vf = get_virtual_frame(caller.linfo)
# sig = Any[]
# ncalls = length(throw_calls)
# for (i, (pc, call)) in enumerate(throw_calls)
# call_sig = get_sig_nowrap((caller.src::CodeInfo, pc), call)
# append!(sig, call_sig)
# i ≠ ncalls && push!(sig, ", ")
# end
# return UncaughtExceptionReport([vf], Signature(sig), throw_calls)
# end
# function print_report_message(io::IO, (; throw_calls)::UncaughtExceptionReport)
# msg = length(throw_calls) == 1 ? "may throw" : "may throw either of"
# print(io, msg)
# end

# report `throw` calls "appropriately"
# this error report pass is very special, since 1.) it's tightly bound to the report pass of
# `SeriousExceptionReport` and 2.) it involves "report filtering" on its own
function (::BasicPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any})
if frame.bestguess === Bottom
report_uncaught_exceptions!(analyzer, frame, stmts)
function (::BasicPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo)
if caller.result === Bottom
report_uncaught_exceptions!(analyzer, caller, src)
return true
else
# the non-`Bottom` result may mean `throw` calls from the children frames
# (if exists) are caught and not propagated here
# we don't want to cache the caught `UncaughtExceptionReport`s for this frame and
# its parents, and just filter them away now
filter!(get_reports(analyzer, frame.result)) do @nospecialize(report::InferenceErrorReport)
filter!(get_reports(analyzer, caller)) do @nospecialize(report::InferenceErrorReport)
return !isa(report, UncaughtExceptionReport)
end
end
return false
end
(::SoundPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any}) =
report_uncaught_exceptions!(analyzer, frame, stmts) # yes, you want tons of false positives !
function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceState, stmts::Vector{Any})
(::SoundPass)(::Type{UncaughtExceptionReport}, analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo) =
report_uncaught_exceptions!(analyzer, caller, src) # yes, you want tons of false positives !
function report_uncaught_exceptions!(analyzer::JETAnalyzer, caller::InferenceResult, src::CodeInfo)
# if the return type here is `Bottom` annotated, this _may_ mean there're uncaught
# `throw` calls
# XXX it's possible that the `throw` calls within them are all caught but the other
# critical errors still make the return type `Bottom`
# NOTE to reduce the false positive cases described above, we count `throw` calls
# after optimization, since it may have eliminated "unreachable" `throw` calls
codelocs = frame.src.codelocs
linetable = frame.src.linetable::LineTable
codelocs = src.codelocs
linetable = src.linetable::LineTable
reported_locs = nothing
for report in get_reports(analyzer, frame.result)
for report in get_reports(analyzer, caller)
if isa(report, SeriousExceptionReport)
if isnothing(reported_locs)
reported_locs = LineInfoNode[]
Expand All @@ -542,7 +545,7 @@ function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceStat
end
end
throw_calls = nothing
for (pc, stmt) in enumerate(stmts)
for (pc, stmt) in enumerate(src.code)
isa(stmt, Expr) || continue
is_throw_call(stmt) || continue
# if this `throw` is already reported, don't duplciate
Expand All @@ -555,7 +558,8 @@ function report_uncaught_exceptions!(analyzer::JETAnalyzer, frame::InferenceStat
push!(throw_calls, (pc, stmt))
end
if !isnothing(throw_calls) && !isempty(throw_calls)
add_new_report!(analyzer, frame.result, UncaughtExceptionReport(frame, throw_calls))
# TODO add_new_report!(analyzer, caller, UncaughtExceptionReport(caller, throw_calls))
add_new_report!(analyzer, caller, UncaughtExceptionReport(caller))
return true
end
return false
Expand Down
14 changes: 4 additions & 10 deletions src/analyzers/optanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ struct OptAnalysisPass <: ReportPass end

optanalyzer_function_filter(@nospecialize ft) = true

# TODO better to work only `finish!`
# TODO better to work only `finish!`, i.e. only work on `CodeInfo` (with static parameters)
function CC.finish(frame::InferenceState, analyzer::OptAnalyzer)
ret = @invoke CC.finish(frame::InferenceState, analyzer::AbstractAnalyzer)

Expand Down Expand Up @@ -271,20 +271,15 @@ function (::OptAnalysisPass)(::Type{CapturedVariableReport}, analyzer::OptAnalyz
return reported
end

function CC.finish!(analyzer::OptAnalyzer, frame::InferenceState)
caller = frame.result

function CC.finish!(analyzer::OptAnalyzer, caller::InferenceResult)
# get the source before running `finish!` to keep the reference to `OptimizationState`
src = caller.src

ret = @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)

if popfirst!(analyzer.__analyze_frame)
ReportPass(analyzer)(OptimizationFailureReport, analyzer, caller)

if (@static VERSION v"1.9.0-DEV.1636" ?
(src isa OptimizationState{typeof(analyzer)}) :
(src isa OptimizationState)) # the compiler optimized it, analyze it
src.ir === nothing || CC.ir_to_codeinf!(src)
ReportPass(analyzer)(RuntimeDispatchReport, analyzer, caller, src)
elseif (@static JET_DEV_MODE ? true : false)
if isa(src, CC.ConstAPI)
Expand All @@ -297,8 +292,7 @@ function CC.finish!(analyzer::OptAnalyzer, frame::InferenceState)
end
end
end

return ret
return @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceResult)
end

# report optimization failure due to recursive calls, etc.
Expand Down
2 changes: 1 addition & 1 deletion test/abstractinterpret/test_inferenceerrorreport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ end
result = report_call(m.foo, (String,))
r = only(get_reports_with_test(result))
@test isa(r, UncaughtExceptionReport)
@test Any['(', 's', String, ')', ArgumentError] r.sig._sig
@test_broken Any['(', 's', String, ')', ArgumentError] r.sig._sig
end

sparams1(::Type{T}) where T = zero(T)
Expand Down

0 comments on commit 2c49256

Please sign in to comment.