diff --git a/base/inference.jl b/base/inference.jl index 7231b31165c4f..41bddae0c65f7 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -2369,6 +2369,36 @@ immutable InvokeData types0 fexpr texpr + orig_atypes +end + +const invoke_arg_error = ErrorException("invoke: argument type error") + +# Assume `argexprs` that we need are all `effect_free` +function emit_invoke_typecheck!(argexprs, invoke_data::InvokeData, stmts, sv) + nargs = length(argexprs) + orig_atypes = invoke_data.orig_atypes::Vector{Any} + expect_types = invoke_data.types0.parameters + pass = genlabel(sv) + fail = genlabel(sv) + for i in 1:nargs + orig_t = orig_atypes[i] + orig_t === nothing && continue + arg = argexprs[i] + expect = expect_types[i] + check = Expr(:call, GlobalRef(Core, :isa), arg, expect) + push!(stmts, Expr(:gotoifnot, check, fail.label)) + # Rewrite the type so that the user (codegen) will see the real type + newvar = newvar!(sv, typeintersect(orig_t, expect)) + push!(stmts, :($newvar = $arg)) + argexprs[i] = newvar + end + push!(stmts, GotoNode(pass.label)) + push!(stmts, fail) + # This should be the same error raised by the real invoke + push!(stmts, Expr(:call, GlobalRef(Core, :throw), invoke_arg_error)) + push!(stmts, pass) + nothing end function inline_as_constant(val::ANY, argexprs, sv::InferenceState, @@ -2376,22 +2406,34 @@ function inline_as_constant(val::ANY, argexprs, sv::InferenceState, if invoke_data === nothing invoke_fexpr = nothing invoke_texpr = nothing + orig_atypes = nothing else invoke_data = invoke_data::InvokeData invoke_fexpr = invoke_data.fexpr invoke_texpr = invoke_data.texpr + orig_atypes = invoke_data.orig_atypes end + need_typecheck = !(orig_atypes === nothing) # check if any arguments aren't effect_free and need to be kept around stmts = invoke_fexpr === nothing ? [] : Any[invoke_fexpr] for i = 1:length(argexprs) arg = argexprs[i] if !effect_free(arg, sv.src, sv.mod, false) - push!(stmts, arg) + if need_typecheck && (orig_atypes::Vector{Any})[i] !== nothing + newvar = newvar!(sv, (orig_atypes::Vector{Any})[i]) + push!(stmts, :($newvar = $arg)) + argexprs[i] = newvar + else + push!(stmts, arg) + end end if i == 1 && !(invoke_texpr === nothing) push!(stmts, invoke_texpr) end end + if need_typecheck + emit_invoke_typecheck!(argexprs, invoke_data::InvokeData, stmts, sv) + end return (QuoteNode(val), stmts) end @@ -2424,11 +2466,14 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY, if invoke_data === nothing invoke_fexpr = nothing invoke_texpr = nothing + orig_atypes = nothing else invoke_data = invoke_data::InvokeData invoke_fexpr = invoke_data.fexpr invoke_texpr = invoke_data.texpr + orig_atypes = invoke_data.orig_atypes end + need_typecheck = !(orig_atypes === nothing) if nu > 1 spec_hit = nothing @@ -2439,7 +2484,7 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY, ex.args = copy(argexprs) ex.typ = etype stmts = [] - arg_hoisted = false + arg_hoisted = need_typecheck for i = length(atypes):-1:1 if i == 1 && !(invoke_texpr === nothing) unshift!(stmts, invoke_texpr) @@ -2450,6 +2495,9 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY, aei = ex.args[i] if !effect_free(aei, sv.src, sv.mod, false) arg_hoisted = true + if need_typecheck && (orig_atypes::Vector{Any})[i] !== nothing + ti = (orig_atypes::Vector{Any})[i] + end newvar = newvar!(sv, ti) unshift!(stmts, :($newvar = $aei)) ex.args[i] = newvar @@ -2457,6 +2505,9 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY, end end invoke_fexpr === nothing || unshift!(stmts, invoke_fexpr) + if need_typecheck + emit_invoke_typecheck!(argexprs, invoke_data::InvokeData, stmts, sv) + end function splitunion(atypes::Vector{Any}, i::Int) if i == 0 local sig = argtypes_to_type(atypes) @@ -2530,11 +2581,11 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY, else local cache_linfo = get_spec_lambda(atype_unlimited, invoke_data) cache_linfo === nothing && return NF - unshift!(argexprs, cache_linfo) ex = Expr(:invoke) ex.args = argexprs ex.typ = etype - if invoke_texpr === nothing + if invoke_texpr === nothing && !need_typecheck + unshift!(argexprs, cache_linfo) if invoke_fexpr === nothing return ex else @@ -2545,6 +2596,23 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY, stmts = Any[invoke_fexpr, :($newvar = $(argexprs[1])), invoke_texpr] argexprs[1] = newvar + if need_typecheck + for i in 2:length(argexprs) + aei = argexprs[i] + if !effect_free(aei, sv.src, sv.mod, false) + if need_typecheck && (orig_atypes::Vector{Any})[i] !== nothing + ti = (orig_atypes::Vector{Any})[i] + else + ti = atypes[i] + end + newvar = newvar!(sv, ti) + push!(stmts, :($newvar = $aei)) + argexprs[i] = newvar + end + end + emit_invoke_typecheck!(argexprs, invoke_data::InvokeData, stmts, sv) + end + unshift!(argexprs, cache_linfo) return ex, stmts end return NF @@ -2602,6 +2670,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference invoke_data = nothing invoke_fexpr = nothing invoke_texpr = nothing + need_typecheck = false if f === Core.invoke && length(atypes) >= 3 ft = widenconst(atypes[2]) invoke_tt = widenconst(atypes[3]) @@ -2624,26 +2693,46 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference if effect_free(invoke_texpr, sv.src, sv.mod, false) invoke_fexpr = nothing end - invoke_data = InvokeData(ft.name.mt, invoke_entry, - invoke_types, invoke_fexpr, invoke_texpr) atype0 = atypes[2] argexpr0 = argexprs[2] atypes = atypes[4:end] argexprs = argexprs[4:end] unshift!(atypes, atype0) unshift!(argexprs, argexpr0) + atype_unlimited = argtypes_to_type(atypes) + if !(atype_unlimited <: invoke_types) + need_typecheck = true + orig_atypes = atypes + invoke_data = InvokeData(ft.name.mt, invoke_entry, + invoke_types, invoke_fexpr, invoke_texpr, + orig_atypes) + invoke_tp = invoke_types.parameters + invoke_nargs = length(atypes) + invoke_nargs == length(invoke_tp) || return NF + atypes = copy(atypes) + for i in 1:invoke_nargs + argt = atypes[i] + invt = invoke_tp[i] + if !(argt ⊑ invt) + atypes[i] = typeintersect(widenconst(argt), invt) + else + orig_atypes[i] = nothing + end + end + atype_unlimited = argtypes_to_type(atypes) + else + invoke_data = InvokeData(ft.name.mt, invoke_entry, + invoke_types, invoke_fexpr, invoke_texpr, + nothing) + end f = isdefined(ft, :instance) ? ft.instance : nothing elseif isa(f, IntrinsicFunction) || ft ⊑ IntrinsicFunction || isa(f, Builtin) || ft ⊑ Builtin return NF + else + atype_unlimited = argtypes_to_type(atypes) end - atype_unlimited = argtypes_to_type(atypes) - if !(invoke_data === nothing) - invoke_data = invoke_data::InvokeData - # TODO emit a type check and proceed for this case - atype_unlimited <: invoke_data.types0 || return NF - end if !sv.inlining return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited, invoke_data) @@ -2829,8 +2918,8 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference end free = effect_free(aei, sv.src, sv.mod, true) if ((occ==0 && is(aeitype,Bottom)) || (occ > 1 && !inline_worthy(aei, occ*2000)) || - (affect_free && !free) || (!affect_free && !effect_free(aei, sv.src, sv.mod, false))) - if occ != 0 + (affect_free && !free) || (!affect_free && !effect_free(aei, sv.src, sv.mod, false)) || need_typecheck) + if occ != 0 || need_typecheck vnew = newvar!(sv, aeitype) argexprs[i] = vnew unshift!(prelude_stmts, Expr(:(=), vnew, aei)) @@ -2842,6 +2931,10 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference end end invoke_fexpr === nothing || unshift!(prelude_stmts, invoke_fexpr) + if need_typecheck + emit_invoke_typecheck!(argexprs, invoke_data::InvokeData, + prelude_stmts, sv) + end # re-number the SSAValues and copy their type-info to the new ast ssavalue_types = src.ssavaluetypes diff --git a/test/core.jl b/test/core.jl index 6d7df6b6d7e56..d81bf251c9e5d 100644 --- a/test/core.jl +++ b/test/core.jl @@ -4594,6 +4594,21 @@ catch e (e::ErrorException).msg end == "generated function body is not pure. this likely means it contains a closure or comprehension." +let x = 1 + global g18444 + @noinline g18444(a) = (x += 1; a[]) + f18444_1(a) = invoke(sin, Tuple{Int}, g18444(a)) + f18444_2(a) = invoke(sin, Tuple{Integer}, g18444(a)) + @test_throws ErrorException f18444_1(Ref{Any}(1.0)) + @test x == 2 + @test_throws ErrorException f18444_2(Ref{Any}(1.0)) + @test x == 3 + @test f18444_1(Ref{Any}(1)) === sin(1) + @test x == 4 + @test f18444_2(Ref{Any}(1)) === sin(1) + @test x == 5 +end + # issue #10981, long argument lists let a = fill(["sdf"], 2*10^6), temp_vcat(x...) = vcat(x...) # we introduce a new function `temp_vcat` to make sure there is no existing