Skip to content

Commit

Permalink
Inline invoke (take 3)
Browse files Browse the repository at this point in the history
Fix #9608
  • Loading branch information
yuyichao committed Oct 9, 2016
1 parent f6882fe commit 689235a
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 29 deletions.
172 changes: 145 additions & 27 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,8 @@ function invoke_tfunc(f::ANY, types::ANY, argtype::ANY, sv::InferenceState)
return Any
end
meth = entry.func
(ti, env) = ccall(:jl_match_method, Any, (Any, Any, Any),
argtype, meth.sig, meth.tvars)::SimpleVector
(ti, env) = ccall(:jl_match_method, Ref{SimpleVector}, (Any, Any, Any),
argtype, meth.sig, meth.tvars)
return typeinf_edge(meth::Method, ti, env, sv)
end

Expand Down Expand Up @@ -2363,14 +2363,34 @@ end

#### post-inference optimizations ####

function inline_as_constant(val::ANY, argexprs, sv::InferenceState)
immutable InvokeData
mt::MethodTable
entry::TypeMapEntry
types0
fexpr
texpr
end

function inline_as_constant(val::ANY, argexprs, sv::InferenceState,
invoke_data::ANY)
if invoke_data === nothing
invoke_fexpr = nothing
invoke_texpr = nothing
else
invoke_data = invoke_data::InvokeData
invoke_fexpr = invoke_data.fexpr
invoke_texpr = invoke_data.texpr
end
# check if any arguments aren't effect_free and need to be kept around
stmts = Any[]
stmts = invoke_fexpr === nothing ? [] : Any[invoke_fexpr]
for i = 1:length(argexprs)
arg = argexprs[i]
if !effect_free(arg, sv.src, sv.mod, false)
push!(stmts, arg)
end
if i == 1 && !(invoke_texpr === nothing)
push!(stmts, invoke_texpr)
end
end
return (QuoteNode(val), stmts)
end
Expand All @@ -2385,10 +2405,31 @@ function countunionsplit(atypes::Vector{Any})
return nu
end

function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY)
function get_spec_lambda(atypes::ANY, invoke_data::ANY)
if invoke_data === nothing
return ccall(:jl_get_spec_lambda, Any, (Any,), atypes)
else
invoke_data = invoke_data::InvokeData
# TODO compute intersection and throws an error
atypes <: invoke_data.types0 || return nothing
return ccall(:jl_get_invoke_lambda, Any, (Any, Any, Any),
invoke_data.mt, invoke_data.entry, atypes)
end
end

function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY,
invoke_data::ANY)
# converts a :call to :invoke
nu = countunionsplit(atypes)
nu > MAX_UNION_SPLITTING && return NF
if invoke_data === nothing
invoke_fexpr = nothing
invoke_texpr = nothing
else
invoke_data = invoke_data::InvokeData
invoke_fexpr = invoke_data.fexpr
invoke_texpr = invoke_data.texpr
end

if nu > 1
spec_hit = nothing
Expand All @@ -2400,7 +2441,12 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY)
ex.typ = etype
stmts = []
arg_hoisted = false
arg0_hoisted = false
for i = length(atypes):-1:1
if i == 1 && !(invoke_texpr === nothing)
unshift!(stmts, invoke_texpr)
arg_hoisted = true
end
ti = atypes[i]
if arg_hoisted || isa(ti, Union)
aei = ex.args[i]
Expand All @@ -2409,13 +2455,17 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY)
newvar = newvar!(sv, ti)
unshift!(stmts, :($newvar = $aei))
ex.args[i] = newvar
if i == 1
arg0_hoisted = true
end
end
end
end
invoke_fexpr === nothing || unshift!(stmts, invoke_fexpr)
function splitunion(atypes::Vector{Any}, i::Int)
if i == 0
local sig = argtypes_to_type(atypes)
local li = ccall(:jl_get_spec_lambda, Any, (Any,), sig)
local li = get_spec_lambda(sig, invoke_data)
li === nothing && return false
local stmt = []
push!(stmt, Expr(:(=), linfo_var, li))
Expand Down Expand Up @@ -2483,13 +2533,24 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY)
return (ret_var, stmts)
end
else
local cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited)
local cache_linfo = get_spec_lambda(atype_unlimited, invoke_data)
cache_linfo === nothing && return NF
unshift!(argexprs, cache_linfo)
ex = Expr(:invoke)
ex.args = argexprs
ex.typ = etype
return ex
if invoke_texpr === nothing
if invoke_fexpr === nothing
return ex
else
return ex, Any[invoke_fexpr]
end
end
newvar = newvar!(sv, atypes[1])
stmts = Any[invoke_fexpr, :($newvar = $(argexprs[1])),
invoke_texpr]
argexprs[1] = newvar
return ex, stmts
end
return NF
end
Expand Down Expand Up @@ -2543,41 +2604,94 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
end
end
if isa(f, IntrinsicFunction) || ft IntrinsicFunction ||
invoke_data = nothing
invoke_fexpr = nothing
invoke_texpr = nothing
if f === Core.invoke && length(atypes) >= 3
ft = widenconst(atypes[2])
invoke_tt = widenconst(atypes[3])
if !isleaftype(ft) || !isleaftype(invoke_tt) || !isType(invoke_tt)
return NF
end
if !(isa(invoke_tt.parameters[1], Type) &&
invoke_tt.parameters[1] <: Tuple)
return NF
end
invoke_tt_params = invoke_tt.parameters[1].parameters
invoke_types = Tuple{ft, invoke_tt_params...}
invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any,), invoke_types)
invoke_entry === nothing && return NF
invoke_fexpr = argexprs[1]
invoke_texpr = argexprs[3]
if effect_free(invoke_fexpr, sv.src, sv.mod, false)
invoke_fexpr = nothing
end
if effect_free(invoke_texpr, sv.src, sv.mod, false)
invoke_fexpr = nothing
end
invoke_data = InvokeData(ft.name.mt, invoke_entry,
invoke_types, invoke_fexpr, invoke_texpr)
atype0 = atypes[2]
argexpr0 = argexprs[2]
atypes = atypes[4:end]
argexprs = argexprs[4:end]
unshift!(atypes, atype0)
unshift!(argexprs, argexpr0)
f = isdefined(ft, :instance) ? ft.instance : nothing
elseif isa(f, IntrinsicFunction) || ft IntrinsicFunction ||
isa(f, Builtin) || ft Builtin
return NF
end

local atype_unlimited = argtypes_to_type(atypes)
atype_unlimited = argtypes_to_type(atypes)
if !(invoke_data === nothing)
invoke_data = invoke_data::InvokeData
# TODO emit a type check and proceed for this case
atype_unlimited <: invoke_data.types0 || return NF
end
if !sv.inlining
return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited)
return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited,
invoke_data)
end

if length(atype_unlimited.parameters) - 1 > MAX_TUPLETYPE_LEN
atype = limit_tuple_type(atype_unlimited)
else
atype = atype_unlimited
end
meth = _methods_by_ftype(atype, 1)
if meth === false || length(meth) != 1
return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited)
if invoke_data === nothing
meth = _methods_by_ftype(atype, 1)
if meth === false || length(meth) != 1
return invoke_NF(argexprs, e.typ, atypes, sv,
atype_unlimited, invoke_data)
end
meth = meth[1]::SimpleVector
metharg = meth[1]::Type
methsp = meth[2]::SimpleVector
method = meth[3]::Method
else
invoke_data = invoke_data::InvokeData
method = invoke_data.entry.func
(metharg, methsp) = ccall(:jl_match_method, Ref{SimpleVector},
(Any, Any, Any),
atype_unlimited, method.sig, method.tvars)
methsp = methsp::SimpleVector
end
meth = meth[1]::SimpleVector
metharg = meth[1]::Type
methsp = meth[2]
method = meth[3]::Method
# check whether call can be inlined to just a quoted constant value
if isa(f, widenconst(ft)) && !method.isstaged && (method.source.pure || f === return_type)
if isconstType(e.typ,false)
return inline_as_constant(e.typ.parameters[1], argexprs, sv)
return inline_as_constant(e.typ.parameters[1], argexprs, sv,
invoke_data)
elseif isa(e.typ,Const)
return inline_as_constant(e.typ.val, argexprs, sv)
return inline_as_constant(e.typ.val, argexprs, sv,
invoke_data)
end
end

methsig = method.sig
if !(atype <: metharg)
return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited)
return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited,
invoke_data)
end

argexprs0 = argexprs
Expand Down Expand Up @@ -2653,11 +2767,12 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference

if isa(linfo, MethodInstance) && linfo.jlcall_api == 2
# in this case function can be inlined to a constant
return inline_as_constant(linfo.inferred, argexprs, sv)
return inline_as_constant(linfo.inferred, argexprs, sv, invoke_data)
end

if !isa(src, CodeInfo) || !src.inferred || !src.inlineable
return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited)
return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited,
invoke_data)
end
ast = src.code
rettype = linfo.rettype
Expand All @@ -2673,8 +2788,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
end

methargs = metharg.parameters
nm = length(methargs)
nm = length(metharg.parameters)

if !isa(ast, Array{Any,1})
ast = ccall(:jl_uncompress_ast, Any, (Any, Any), method, ast)
Expand All @@ -2688,14 +2802,17 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
propagate_inbounds = src.propagate_inbounds

# see if each argument occurs only once in the body expression
stmts = Any[]
prelude_stmts = Any[]
stmts = []
prelude_stmts = []
stmts_free = true # true = all entries of stmts are effect_free

for i=na:-1:1 # stmts_free needs to be calculated in reverse-argument order
#args_i = args[i]
aei = argexprs[i]
aeitype = argtype = widenconst(exprtype(aei, sv.src, sv.mod))
if i == 1 && !(invoke_texpr === nothing)
unshift!(prelude_stmts, invoke_texpr)
end

# ok for argument to occur more than once if the actual argument
# is a symbol or constant, or is not affected by previous statements
Expand Down Expand Up @@ -2729,6 +2846,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
end
end
invoke_fexpr === nothing || unshift!(prelude_stmts, invoke_fexpr)

# re-number the SSAValues and copy their type-info to the new ast
ssavalue_types = src.ssavaluetypes
Expand Down
Loading

0 comments on commit 689235a

Please sign in to comment.