Skip to content

Commit

Permalink
Put inference parameters in a dedicated object.
Browse files Browse the repository at this point in the history
This allows overriding parameters, while reducing global state
and simplifying the API.
  • Loading branch information
maleadt committed Oct 13, 2016
1 parent 6839451 commit 73f3de3
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 51 deletions.
3 changes: 2 additions & 1 deletion base/REPLCompletions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ function get_type_call(expr::Expr)
length(mt) == 1 || return (Any, false)
m = first(mt)
# Typeinference
return_type = Core.Inference.typeinf_type(m[3], m[1], m[2])
params = Core.Inference.InferenceParams()
return_type = Core.Inference.typeinf_type(m[3], m[1], m[2], true, params)
return_type === nothing && return (Any, false)
return (return_type, true)
end
Expand Down
123 changes: 75 additions & 48 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,25 @@ import Core: _apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstanc
#### parameters limiting potentially-infinite types ####
const MAX_TYPEUNION_LEN = 3
const MAX_TYPE_DEPTH = 7
const MAX_TUPLETYPE_LEN = 15
const MAX_TUPLE_DEPTH = 4

const MAX_TUPLE_SPLAT = 16
const MAX_UNION_SPLITTING = 4
immutable InferenceParams
# optimization
inlining::Bool

# parameters limiting potentially-infinite types (configurable)
MAX_TUPLETYPE_LEN::Int
MAX_TUPLE_DEPTH::Int
MAX_TUPLE_SPLAT::Int
MAX_UNION_SPLITTING::Int

# reasonable defaults
InferenceParams(;inlining::Bool=inlining_enabled(),
tupletype_len::Int=15, tuple_depth::Int=4,
tuple_splat::Int=16, union_splitting::Int=4) =
new(inlining, tupletype_len,
tuple_depth, tuple_splat, union_splitting)
end

const UNION_SPLIT_MISMATCH_ERROR = false

# alloc_elim_pass! relies on `Slot_AssignedOnce | Slot_UsedUndef` being
Expand Down Expand Up @@ -49,6 +63,8 @@ type InferenceState
mod::Module
currpc::LineNum

params::InferenceParams

# info on the state of inference and the linfo
linfo::MethodInstance # used here for the tuple (specTypes, env, Method)
src::CodeInfo
Expand All @@ -74,14 +90,16 @@ type InferenceState
# iteration fixed-point detection
fixedpoint::Bool
inworkq::Bool
# optimization

# TODO: put these in InferenceParams (depends on proper multi-methodcache support)
optimize::Bool
inlining::Bool
cached::Bool

inferred::Bool

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(linfo::MethodInstance, src::CodeInfo, optimize::Bool, inlining::Bool, cached::Bool)
function InferenceState(linfo::MethodInstance, src::CodeInfo,
optimize::Bool, cached::Bool, params::InferenceParams)
code = src.code::Array{Any,1}
nl = label_counter(code) + 1
toplevel = !isdefined(linfo, :def)
Expand Down Expand Up @@ -113,7 +131,7 @@ type InferenceState
end
s[1][la] = VarState(Tuple, false)
else
s[1][la] = VarState(tuple_tfunc(limit_tuple_depth(tupletype_tail(atypes, la))), false)
s[1][la] = VarState(tuple_tfunc(limit_tuple_depth(params, tupletype_tail(atypes, la))), false)
end
la -= 1
end
Expand Down Expand Up @@ -168,12 +186,13 @@ type InferenceState
inmodule = toplevel ? current_module() : linfo.def.module # toplevel thunks are inferred in the current module
frame = new(
sp, nl, inmodule, 0,
params,
linfo, src, nargs, s, Union{}, W, n,
cur_hand, handler_at, n_handlers,
ssavalue_uses, ssavalue_init,
ObjectIdDict(), #Dict{InferenceState, Vector{LineNum}}(),
Vector{Tuple{InferenceState, Vector{LineNum}}}(),
false, false, optimize, inlining, cached, false)
false, false, optimize, cached, false)
push!(active, frame)
nactive[] += 1
return frame
Expand Down Expand Up @@ -669,7 +688,7 @@ function invoke_tfunc(f::ANY, types::ANY, argtype::ANY, sv::InferenceState)
if !isleaftype(Type{types})
return Any
end
argtype = typeintersect(types,limit_tuple_type(argtype))
argtype = typeintersect(types,limit_tuple_type(argtype, sv.params))
if is(argtype,Bottom)
return Bottom
end
Expand Down Expand Up @@ -703,7 +722,7 @@ function builtin_tfunction(f::ANY, argtypes::Array{Any,1}, sv::InferenceState)
if is(f,tuple)
for a in argtypes
if !isa(a, Const)
return tuple_tfunc(limit_tuple_depth(argtypes_to_type(argtypes)))
return tuple_tfunc(limit_tuple_depth(sv.params, argtypes_to_type(argtypes)))
end
end
return Const(tuple(map(a->a.val, argtypes)...))
Expand Down Expand Up @@ -766,28 +785,28 @@ function builtin_tfunction(f::ANY, argtypes::Array{Any,1}, sv::InferenceState)
return tf[3](argtypes...)
end

limit_tuple_depth(t::ANY) = limit_tuple_depth_(t,0)
limit_tuple_depth(params::InferenceParams, t::ANY) = limit_tuple_depth_(params,t,0)

function limit_tuple_depth_(t::ANY, d::Int)
function limit_tuple_depth_(params::InferenceParams, t::ANY, d::Int)
if isa(t,Union)
# also limit within Union types.
# may have to recur into other stuff in the future too.
return Union{map(x->limit_tuple_depth_(x,d+1), t.types)...}
return Union{map(x->limit_tuple_depth_(params,x,d+1), t.types)...}
end
if isa(t,TypeVar)
return limit_tuple_depth_(t.ub, d)
return limit_tuple_depth_(params, t.ub, d)
end
if !(isa(t,DataType) && t.name === Tuple.name)
return t
end
if d > MAX_TUPLE_DEPTH
if d > params.MAX_TUPLE_DEPTH
return Tuple
end
p = map(x->limit_tuple_depth_(x,d+1), t.parameters)
p = map(x->limit_tuple_depth_(params,x,d+1), t.parameters)
Tuple{p...}
end

limit_tuple_type = (t::ANY) -> limit_tuple_type_n(t, MAX_TUPLETYPE_LEN)
limit_tuple_type = (t::ANY, params::InferenceParams) -> limit_tuple_type_n(t, params.MAX_TUPLETYPE_LEN)

function limit_tuple_type_n(t::ANY, lim::Int)
p = t.parameters
Expand All @@ -802,7 +821,7 @@ end

#### recursing into expression ####

function abstract_call_gf_by_type(f::ANY, argtype::ANY, sv)
function abstract_call_gf_by_type(f::ANY, argtype::ANY, sv::InferenceState)
tm = _topmod(sv)
# don't consider more than N methods. this trades off between
# compiler performance and generated code performance.
Expand All @@ -811,7 +830,7 @@ function abstract_call_gf_by_type(f::ANY, argtype::ANY, sv)
# It is important for N to be >= the number of methods in the error()
# function, so we can still know that error() is always Bottom.
# here I picked 4.
argtype = limit_tuple_type(argtype)
argtype = limit_tuple_type(argtype, sv.params)
argtypes = argtype.parameters
applicable = _methods_by_ftype(argtype, 4)
rettype = Bottom
Expand Down Expand Up @@ -1006,9 +1025,10 @@ function abstract_apply(af::ANY, fargs, aargtypes::Vector{Any}, vtypes::VarTable
# can be collapsed to a call to the applied func
at = append_any(Any[type_typeof(af)], ctypes...)
n = length(at)
if n-1 > MAX_TUPLETYPE_LEN
tail = foldl((a,b)->tmerge(a,unwrapva(b)), Bottom, at[MAX_TUPLETYPE_LEN+1:n])
at = vcat(at[1:MAX_TUPLETYPE_LEN], Any[Vararg{tail}])
if n-1 > sv.params.MAX_TUPLETYPE_LEN
tail = foldl((a,b)->tmerge(a,unwrapva(b)), Bottom,
at[sv.params.MAX_TUPLETYPE_LEN+1:n])
at = vcat(at[1:sv.params.MAX_TUPLETYPE_LEN], Any[Vararg{tail}])
end
return abstract_call(af, (), at, vtypes, sv)
end
Expand All @@ -1024,6 +1044,7 @@ function pure_eval_call(f::ANY, argtypes::ANY, atype::ANY, vtypes::VarTable, sv:
end

if f === return_type && length(argtypes) == 3
# NOTE: only considering calls to return_type without InferenceParams arg
tt = argtypes[3]
if isType(tt)
af_argtype = tt.parameters[1]
Expand Down Expand Up @@ -1127,7 +1148,7 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
return Type
end

if sv.inlining
if sv.params.inlining
# need to model the special inliner for ^
# to ensure we have added the same edge
if isdefined(Main, :Base) &&
Expand Down Expand Up @@ -1520,7 +1541,8 @@ end


# build (and start inferring) the inference frame for the linfo
function typeinf_frame(linfo::MethodInstance, optimize::Bool, cached::Bool, caller)
function typeinf_frame(linfo::MethodInstance, caller, optimize::Bool, cached::Bool,
params::InferenceParams)
frame = nothing
if linfo.inInference
# inference on this signature may be in progress,
Expand Down Expand Up @@ -1549,7 +1571,7 @@ function typeinf_frame(linfo::MethodInstance, optimize::Bool, cached::Bool, call
src = get_source(linfo)
end
linfo.inInference = true
frame = InferenceState(linfo, src, optimize, inlining_enabled(), cached)
frame = InferenceState(linfo, src, optimize, cached, params)
end
frame = frame::InferenceState

Expand Down Expand Up @@ -1590,7 +1612,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, caller
end
end
end
frame = typeinf_frame(code, true, true, caller)
frame = typeinf_frame(code, caller, true, true, caller.params)
frame === nothing && return Any
frame = frame::InferenceState
return frame.bestguess
Expand All @@ -1599,12 +1621,14 @@ end
#### entry points for inferring a MethodInstance given a type signature ####

# compute an inferred AST and return type
function typeinf_code(method::Method, atypes::ANY, sparams::SimpleVector, optimize::Bool, cached::Bool)
function typeinf_code(method::Method, atypes::ANY, sparams::SimpleVector,
optimize::Bool, cached::Bool, params::InferenceParams)
code = code_for_method(method, atypes, sparams)
code === nothing && return (nothing, Any)
return typeinf_code(code::MethodInstance, optimize, cached)
return typeinf_code(code::MethodInstance, optimize, cached, params)
end
function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool)
function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
params::InferenceParams)
for i = 1:2 # test-and-lock-and-test
if cached && isdefined(linfo, :inferred)
# see if this code already exists in the cache
Expand Down Expand Up @@ -1635,7 +1659,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool)
end
i == 1 && ccall(:jl_typeinf_begin, Void, ())
end
frame = typeinf_frame(linfo, optimize, cached, nothing)
frame = typeinf_frame(linfo, nothing, optimize, cached, params)
ccall(:jl_typeinf_end, Void, ())
frame === nothing && return (nothing, Any)
frame = frame::InferenceState
Expand All @@ -1644,7 +1668,8 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool)
end

# compute (and cache) an inferred AST and return the inferred return type
function typeinf_type(method::Method, atypes::ANY, sparams::SimpleVector, cached::Bool=true)
function typeinf_type(method::Method, atypes::ANY, sparams::SimpleVector,
cached::Bool, params::InferenceParams)
code = code_for_method(method, atypes, sparams)
code === nothing && return nothing
code = code::MethodInstance
Expand All @@ -1661,7 +1686,7 @@ function typeinf_type(method::Method, atypes::ANY, sparams::SimpleVector, cached
end
i == 1 && ccall(:jl_typeinf_begin, Void, ())
end
frame = typeinf_frame(code, cached, cached, nothing)
frame = typeinf_frame(code, nothing, cached, cached, params)
ccall(:jl_typeinf_end, Void, ())
frame === nothing && return nothing
frame = frame::InferenceState
Expand All @@ -1672,13 +1697,14 @@ end
function typeinf_ext(linfo::MethodInstance)
if isdefined(linfo, :def)
# method lambda - infer this specialization via the method cache
(code, typ) = typeinf_code(linfo, true, true)
(code, typ) = typeinf_code(linfo, true, true, InferenceParams())
return code
else
# toplevel lambda - infer directly
linfo.inInference = true
ccall(:jl_typeinf_begin, Void, ())
frame = InferenceState(linfo, linfo.inferred::CodeInfo, true, inlining_enabled(), true)
frame = InferenceState(linfo, linfo.inferred::CodeInfo,
true, true, InferenceParams())
typeinf_loop(frame)
ccall(:jl_typeinf_end, Void, ())
@assert frame.inferred # TODO: deal with this better
Expand Down Expand Up @@ -2138,7 +2164,6 @@ function type_annotate!(sv::InferenceState)
body = src.code::Array{Any,1}
nexpr = length(body)
i = 1
optimize = sv.optimize::Bool
while i <= nexpr
st_i = states[i]
expr = body[i]
Expand All @@ -2151,7 +2176,7 @@ function type_annotate!(sv::InferenceState)
id = expr.args[1].id
record_slot_type!(id, widenconst(states[i+1][id].typ), src.slottypes)
end
elseif optimize
elseif sv.optimize
if ((isa(expr, Expr) && is_meta_expr(expr::Expr)) ||
isa(expr, LineNumberNode))
i += 1
Expand Down Expand Up @@ -2426,7 +2451,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
topmod = _topmod(sv)
# special-case inliners for known pure functions that compute types
if sv.inlining
if sv.params.inlining
if isconstType(e.typ,true)
if (is(f, apply_type) || is(f, fieldtype) || is(f, typeof) ||
istopfunction(topmod, f, :typejoin) ||
Expand Down Expand Up @@ -2460,7 +2485,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
function invoke_NF()
# converts a :call to :invoke
local nu = countunionsplit(atypes)
nu > MAX_UNION_SPLITTING && return NF
nu > sv.params.MAX_UNION_SPLITTING && return NF

if nu > 1
local spec_hit = nothing
Expand Down Expand Up @@ -2560,12 +2585,12 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
return NF
end
if !sv.inlining
if !sv.params.inlining
return invoke_NF()
end

if length(atype_unlimited.parameters) - 1 > MAX_TUPLETYPE_LEN
atype = limit_tuple_type(atype_unlimited)
if length(atype_unlimited.parameters) - 1 > sv.params.MAX_TUPLETYPE_LEN
atype = limit_tuple_type(atype_unlimited, sv.params)
else
atype = atype_unlimited
end
Expand Down Expand Up @@ -2646,7 +2671,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
linfo = code_for_method(method, metharg, methsp)
end
if isa(linfo, MethodInstance)
frame = typeinf_frame(linfo::MethodInstance, true, true, nothing)
frame = typeinf_frame(linfo::MethodInstance, nothing, true, true, sv.params)
end
end
if isa(frame, InferenceState) && frame.inferred
Expand Down Expand Up @@ -3071,7 +3096,7 @@ function inlining_pass(e::Expr, sv::InferenceState)
end
end

if sv.inlining
if sv.params.inlining
if isdefined(Main, :Base) &&
((isdefined(Main.Base, :^) && is(f, Main.Base.:^)) ||
(isdefined(Main.Base, :.^) && is(f, Main.Base.:.^))) &&
Expand Down Expand Up @@ -3150,7 +3175,7 @@ function inlining_pass(e::Expr, sv::InferenceState)
elseif isa(aarg, Tuple)
newargs[i-2] = Any[ QuoteNode(x) for x in aarg ]
elseif isa(t, DataType) && t.name === Tuple.name && !isvatuple(t) &&
effect_free(aarg, sv.src, sv.mod, true) && length(t.parameters) <= MAX_TUPLE_SPLAT
effect_free(aarg, sv.src, sv.mod, true) && length(t.parameters) <= sv.params.MAX_TUPLE_SPLAT
# apply(f,t::(x,y)) => f(t[1],t[2])
tp = t.parameters
newargs[i-2] = Any[ mk_getfield(aarg,j,tp[j]) for j=1:length(tp) ]
Expand Down Expand Up @@ -3951,10 +3976,12 @@ function reindex_labels!(sv::InferenceState)
end
end

function return_type(f::ANY, t::ANY)
function return_type(f::ANY, t::ANY, params::InferenceParams=InferenceParams())
# NOTE: if not processed by pure_eval_call during inference, a call to return_type
# might use difference InferenceParams than the method it is contained in...
rt = Union{}
for m in _methods(f, t, -1)
ty = typeinf_type(m[3], m[1], m[2])
ty = typeinf_type(m[3], m[1], m[2], true, params)
ty === nothing && return Any
rt = tmerge(rt, ty)
rt === Any && break
Expand Down Expand Up @@ -3988,7 +4015,7 @@ let fs = Any[typeinf_ext, typeinf_loop, typeinf_edge, occurs_outside_getfield, e
typ[i] = typ[i].ub
end
end
typeinf_type(m[3], Tuple{typ...}, m[2])
typeinf_type(m[3], Tuple{typ...}, m[2], true, InferenceParams())
end
end
end
Loading

0 comments on commit 73f3de3

Please sign in to comment.