Skip to content

Commit

Permalink
wip: cache reports with internal ocde cache
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Aug 27, 2020
1 parent b5757d6 commit beb8559
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 26 deletions.
8 changes: 7 additions & 1 deletion src/TypeProfiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import Core.Compiler:
# abstractinterpretation.jl
abstract_call_gf_by_type, abstract_call_known, abstract_call,
abstract_eval_special_value, abstract_eval_value_expr, abstract_eval_value,
abstract_eval_statement, builtin_tfunction, typeinf_local
abstract_eval_statement, builtin_tfunction, typeinf_local,
# tpcache.jl
WorldView

# usings
# ------
Expand All @@ -23,6 +25,7 @@ import Core:

import Core.Compiler:
AbstractInterpreter, NativeInterpreter, InferenceState, InferenceResult, CodeInfo,
InternalCodeCache, CodeInstance, WorldRange,
MethodInstance, Bottom, NOT_FOUND, MethodMatchInfo, UnionSplitInfo, MethodLookupResult,
Const, VarTable, SSAValue, SlotNumber, Slot, slot_id, GlobalRef, GotoIfNot, ReturnNode,
widenconst, isconstType, typeintersect, , Builtin, CallMeta,
Expand All @@ -39,6 +42,8 @@ import Base.Iterators:

using FileWatching, Requires

const CC = Core.Compiler

# includes
# --------

Expand All @@ -51,6 +56,7 @@ include("virtualprocess.jl")
include("abstractinterpreterinterface.jl")
include("abstractinterpretation.jl")
include("tfuncs.jl")
include("tpcache.jl")
include("print.jl")
include("profile.jl")
include("watch.jl")
Expand Down
15 changes: 10 additions & 5 deletions src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ isroot(frame::InferenceState) = isnothing(frame.parent)
# report undef var error
function check_global_ref!(interp::TPInterpreter, sv::InferenceState, m::Module, s::Symbol)
return if !isdefined(m, s)
add_remark!(interp, sv, UndefVarErrorReport(sv, m, s))
add_remark!(interp, sv, UndefVarErrorReport(interp, sv, m, s))
true
else
false
Expand All @@ -45,13 +45,13 @@ function abstract_call_gf_by_type(interp::TPInterpreter, @nospecialize(f), argty
if isa(info.results, MethodLookupResult) && isempty(info.results.matches)
# no method match for this union split
# ret.rt = Bottom # maybe we want to be more strict on error cases ?
add_remark!(interp, sv, NoMethodErrorReport(sv, true))
add_remark!(interp, sv, NoMethodErrorReport(interp, sv, true))
end
end
elseif isa(info, MethodMatchInfo) && isa(info.results, MethodLookupResult) && isempty(info.results.matches)
# really no method found
typeassert(ret.rt, TypeofBottom) # return type is initialized as `Bottom`, and should never change in these passes
add_remark!(interp, sv, NoMethodErrorReport(sv, false))
add_remark!(interp, sv, NoMethodErrorReport(interp, sv, false))
end

return ret
Expand All @@ -67,7 +67,7 @@ function abstract_eval_special_value(interp::TPInterpreter, @nospecialize(e), vt
# t = vtypes[id].typ
# if t === NOT_FOUND || t === Bottom
# s = sv.src.slotnames[id]
# add_remark!(interp, sv, UndefVarErrorReport(sv, sv.mod, s))
# add_remark!(interp, sv, UndefVarErrorReport(interp, sv, sv.mod, s))
# end
elseif isa(e, GlobalRef)
vgv = getvirtualglobalvar(interp, e.mod, e.name)
Expand All @@ -89,7 +89,7 @@ function abstract_eval_value(interp::TPInterpreter, @nospecialize(e), vtypes::Va
if isa(stmt, GotoIfNot)
t = widenconst(ret)
if t !== Bottom && !(Bool, t)
add_remark!(interp, sv, NonBooleanCondErrorReport(sv, t))
add_remark!(interp, sv, NonBooleanCondErrorReport(interp, sv, t))
ret = Bottom
end
end
Expand All @@ -114,6 +114,8 @@ end
# virtual global assignments should happen here because `SlotNumber`s can be optimized away
# after the optimization happens
function typeinf_local(interp::TPInterpreter, frame::InferenceState)
set_current_frame!(interp, frame)

ret = invoke_native(typeinf_local, interp, frame)

# virtual global variable assignment
Expand All @@ -126,6 +128,9 @@ function typeinf_local(interp::TPInterpreter, frame::InferenceState)
return ret
end

set_current_frame!(interp::TPInterpreter, frame::InferenceState) = interp.frame[] = frame
get_current_frame(interp::TPInterpreter) = interp.frame[]

istoplevel(interp::TPInterpreter) = interp.istoplevel

function getvirtualglobalvar(interp, mod, sym)
Expand Down
16 changes: 13 additions & 3 deletions src/abstractinterpreterinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ struct TPInterpreter <: AbstractInterpreter
discard_trees::Bool

# TypeProfiler.jl specific
frame::Ref{InferenceState}
istoplevel::Bool
virtualglobalvartable::Dict{Module,Dict{Symbol,Any}} # maybe we don't need this nested dicts

Expand All @@ -22,15 +23,24 @@ struct TPInterpreter <: AbstractInterpreter
virtualglobalvartable::AbstractDict = Dict()
)
native = NativeInterpreter(world; inf_params, opt_params)
return new(native, [], optimize, compress, discard_trees, istoplevel, virtualglobalvartable)
return new(native,
[],
optimize,
compress,
discard_trees,
Ref{InferenceState}(),
istoplevel,
virtualglobalvartable
)
end
end

InferenceParams(interp::TPInterpreter) = InferenceParams(interp.native)
OptimizationParams(interp::TPInterpreter) = OptimizationParams(interp.native)
get_world_counter(interp::TPInterpreter) = get_world_counter(interp.native)
get_inference_cache(interp::TPInterpreter) = get_inference_cache(interp.native)
code_cache(interp::TPInterpreter) = code_cache(interp.native)

code_cache(interp::TPInterpreter) = TPCache(interp, code_cache(interp.native))

# TP only works for runtime inference
lock_mi_inference(::TPInterpreter, ::MethodInstance) = nothing
Expand All @@ -41,7 +51,7 @@ function add_remark!(interp::TPInterpreter, ::InferenceState, report::InferenceE
return
end
function add_remark!(interp::TPInterpreter, sv::InferenceState, s::String)
add_remark!(interp, sv, NativeRemark(sv, s))
add_remark!(interp, sv, NativeRemark(interp, sv, s))
return
end

Expand Down
78 changes: 62 additions & 16 deletions src/reports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ const VirtualFrame = @NamedTuple begin
sig::String
end
const VirtualStackTrace = Vector{VirtualFrame}
const ViewedVirtualStackTrace = typeof(view(VirtualStackTrace(), :))

"""
const VirtualFrame = NamedTuple{(:file,:line,:sig),Tuple{Symbol,Int,String}}
const VirtualStackTrace = Vector{VirtualFrame}
const ViewedVirtualStackTrace = typeof(view(VirtualStackTrace(), :))
- `VirtualStackTrace` represents virtual back trace of profiled errors (supposed to be
ordered from call site to error point
- `ViewedVirtualStackTrace` is view of `VirtualStackTrace` and will be kept in [`TPCACHE`](@ref)
"""
VirtualFrame, VirtualStackTrace, ViewedVirtualStackTrace

# helps inference
function Base.getproperty(er::InferenceErrorReport, sym::Symbol)
Expand Down Expand Up @@ -95,51 +107,67 @@ struct NoMethodErrorReport <: InferenceErrorReport
msg::String
sig::String

function NoMethodErrorReport(sv::InferenceState, unionsplit)
st = track_abstract_call_stack!(sv)
function NoMethodErrorReport(interp, sv::InferenceState, unionsplit)
msg = get_msg(NoMethodErrorReport, sv, unionsplit)
sig = get_sig(sv)
st = track_abstract_call_stack!(sv) do sv, st
key = hash(sv.linfo)
TPCACHE[key] = hash(interp) => InferenceReportCache{NoMethodErrorReport}(view(st, :), msg, sig)
end
return new(st, msg, sig)
end
NoMethodErrorReport(args...) = new(args...)
end

struct InvalidBuiltinCallErrorReport <: InferenceErrorReport
st::VirtualStackTrace
msg::String
sig::String

function InvalidBuiltinCallErrorReport(sv::InferenceState)
st = track_abstract_call_stack!(sv)
function InvalidBuiltinCallErrorReport(interp, sv::InferenceState)
msg = get_msg(InvalidBuiltinCallErrorReport, sv)
sig = get_sig(sv)
st = track_abstract_call_stack!(sv) do sv, st
key = hash(sv.linfo)
TPCACHE[key] = hash(interp) => InferenceReportCache{InvalidBuiltinCallErrorReport}(view(st, :), msg, sig)
end
return new(st, msg, sig)
end
InvalidBuiltinCallErrorReport(args...) = new(args...)
end

struct UndefVarErrorReport <: InferenceErrorReport
st::VirtualStackTrace
msg::String
sig::String

function UndefVarErrorReport(sv::InferenceState, mod, name)
st = track_abstract_call_stack!(sv)
function UndefVarErrorReport(interp, sv::InferenceState, mod, name)
msg = get_msg(UndefVarErrorReport, sv, mod, name)
sig = get_sig(sv)
st = track_abstract_call_stack!(sv) do sv, st
key = hash(sv.linfo)
TPCACHE[key] = hash(interp) => InferenceReportCache{UndefVarErrorReport}(view(st, :), msg, sig)
end
return new(st, msg, sig)
end
UndefVarErrorReport(args...) = new(args...)
end

struct NonBooleanCondErrorReport <: InferenceErrorReport
st::VirtualStackTrace
msg::String
sig::String

function NonBooleanCondErrorReport(sv::InferenceState, @nospecialize(t))
st = track_abstract_call_stack!(sv)
function NonBooleanCondErrorReport(interp, sv::InferenceState, @nospecialize(t))
msg = get_msg(NonBooleanCondErrorReport, sv, t)
sig = get_sig(sv)
st = track_abstract_call_stack!(sv) do sv, st
key = hash(sv.linfo)
TPCACHE[key] = hash(interp) => InferenceReportCache{NonBooleanCondErrorReport}(view(st, :), msg, sig)
end
return new(st, msg, sig)
end
NonBooleanCondErrorReport(args...) = new(args...)
end

"""
Expand All @@ -153,24 +181,37 @@ struct NativeRemark <: InferenceErrorReport
msg::String
sig::String

function NativeRemark(sv::InferenceState, s)
st = track_abstract_call_stack!(sv)
function NativeRemark(interp, sv::InferenceState, s)
msg = get_msg(NativeRemark, sv, s)
sig = get_sig(sv)
st = track_abstract_call_stack!(sv) do sv, st
key = hash(sv.linfo)
TPCACHE[key] = hash(interp) => InferenceReportCache{NativeRemark}(view(st, :), msg, sig)
end
return new(st, msg, sig)
end
NativeRemark(args...) = new(args...)
end

# traces the current abstract call stack
function track_abstract_call_stack!(sv, st = VirtualFrame[])::VirtualStackTrace
isroot(sv) || track_abstract_call_stack!(sv.parent, st) # prewalk
function track_abstract_call_stack!(f, sv, st = VirtualFrame[])::VirtualStackTrace
sig = get_sig(sv)
file, line = get_file_line(sv)
push!(st, (; file, line, sig))
frame = (; file, line, sig)

pushfirst!(st, frame) # NOTE: change this to `push!` if this turns out to be a bottleneck
f(sv, st)

isroot(sv) || track_abstract_call_stack!(f, sv.parent, st) # postwalk

return st
end

function get_file_line(frame::InferenceState)
if length(frame.src.code) < get_cur_pc(frame)
return :FIXME, -1
end

loc = frame.src.codelocs[get_cur_pc(frame)]
linfo = frame.src.linetable[loc]
return linfo.file, linfo.line
Expand All @@ -194,7 +235,12 @@ function get_sig(linfo::MethodInstance)
end

# FIXME: obviously these implementations are not exhaustive
get_sig(sv::InferenceState) = get_sig(sv, get_cur_stmt(sv))
function get_sig(sv::InferenceState)
if length(sv.src.code) < get_cur_pc(sv)
return "FIXME"
end
get_sig(sv, get_cur_stmt(sv))
end
function get_sig(sv::InferenceState, expr::Expr)
head = expr.head
return if head === :call
Expand All @@ -213,8 +259,8 @@ function get_sig(sv::InferenceState, expr::Expr)
end
function get_sig(sv::InferenceState, ssa::SSAValue)
ssa_sig = get_sig(sv, sv.src.code[ssa.id])
typ = widenconst(sv.src.ssavaluetypes[ssa.id])
return string(ssa_sig, "::", typ)
typ = string(widenconst(sv.src.ssavaluetypes[ssa.id]))
return endswith(ssa_sig, typ) ? ssa_sig : string(ssa_sig, "::", typ)
end
function get_sig(sv::InferenceState, slot::SlotNumber)
slot_sig = string(sv.src.slotnames[slot.id])
Expand Down
2 changes: 1 addition & 1 deletion src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function builtin_tfunction(interp::TPInterpreter, @nospecialize(f), argtypes::Ar
if f === throw
# TODO: needs a special case here
elseif ret === Bottom
add_remark!(interp, sv, InvalidBuiltinCallErrorReport(sv))
add_remark!(interp, sv, InvalidBuiltinCallErrorReport(interp, sv))
end

return ret
Expand Down
51 changes: 51 additions & 0 deletions src/tpcache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# global cache
# ------------

struct InferenceReportCache{T<:InferenceErrorReport}
st::ViewedVirtualStackTrace
msg::String
sig::String
end
const TPCACHE = Dict{UInt,Pair{UInt,InferenceReportCache}}()

function get_report_cache(sv::InferenceState, cache::InferenceReportCache{T}) where {T<:InferenceErrorReport}
cur_st = track_abstract_call_stack!((args...)->nothing, sv)
st = append!(cur_st, cache.st)
return T(st, cache.msg, cache.sig)
end

# code cache interface
# --------------------

struct TPCache{NativeCache}
interp::TPInterpreter
native::NativeCache
TPCache(interp, native::NativeCache) where {NativeCache} = new{NativeCache}(interp, native)
end
WorldView(tpc::TPCache, wr::WorldRange) = TPCache(tpc.interp, WorldView(tpc.native, wr))

CC.haskey(tpc::TPCache, mi::MethodInstance) = CC.haskey(tpc.native, mi)

function CC.get(tpc::TPCache, mi::MethodInstance, default)
ret = CC.get(tpc.native, mi, default)

# cache hit
if ret !== default
key = hash(mi)
if haskey(TPCACHE, key)
interp_key, cache = TPCACHE[key]
interp = tpc.interp
if hash(interp) != interp_key
frame = get_current_frame(interp)
report = get_report_cache(frame, cache)
push!(interp.reports, report)
end
end
end

return ret
end

CC.getindex(tpc::TPCache, mi::MethodInstance) = CC.getindex(tpc.native, mi)

CC.setindex!(tpc::TPCache, ci::CodeInstance, mi::MethodInstance) = CC.setindex!(tpc.native, ci, mi)

0 comments on commit beb8559

Please sign in to comment.