From 4d8a57bece49af5cedcec0f825aefb4552113b3d Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 17 Dec 2020 13:25:23 +0900 Subject: [PATCH] inference: inter-procedural conditional constraint back-propagation This PR propagates `Conditional`s inter-procedurally when a `Conditional` at return site imposes a constraint on the call arguments. When inference exits the local frame and the return type is annotated as `Conditional`, it will be converted into an `InterConditional` object, which is implemented in `Core` and can be directly put into the global cache. Finally, after going back to the caller frame, `InterConditional` will be re-converted into `Conditional` in the context of the caller frame. So now some simple "is-wrapper" functions will propagate their constraints as expected, e.g.: ```julia isaint(a) = isa(a, Int) @test Base.return_types((Any,)) do a isaint(a) && return a # a::Int return 0 end == [Int] ``` This PR also tweaks `isnothing` and `ismissing` so that there is no longer any inferrability penalty for using them instead of `x === nothing` or `x === missing`, e.g.: ```julia @test Base.return_types((Union{Nothing,Int},)) do a isnothing(a) && return 0 return a # a::Int end == [Int] ``` (and now we don't need something like #38636) There are certain limitations. One of the biggest ones would be that it can propagate constraints only on a single argument, and it fails to back-propagate constraints when there are multiple conditions on different slots, e.g.
`Meta.isexpr` can't still propagate its type constraint to the caller: ```julia @test_broken Base.return_types((Any,)) do x Meta.isexpr(x, :call) && return x # still x::Any but ideally x::Expr return nothing end == [Nothing,Expr] ``` (and because of this reason, this PR can't close #37342) --- base/boot.jl | 1 + base/compiler/abstractinterpretation.jl | 31 +++++++++++++--- base/compiler/tfuncs.jl | 3 ++ base/compiler/typeinfer.jl | 49 ++++++++++++++++++------- base/compiler/typelattice.jl | 35 +++++++++++------- base/compiler/typelimits.jl | 29 +++++++++++++++ base/essentials.jl | 3 +- base/some.jl | 3 +- src/builtins.c | 1 + src/jl_exported_data.inc | 1 + src/jltypes.c | 4 ++ src/julia.h | 1 + src/staticdata.c | 1 + test/compiler/inference.jl | 35 ++++++++++++++++++ 14 files changed, 161 insertions(+), 36 deletions(-) diff --git a/base/boot.jl b/base/boot.jl index 149b940d5d352f..c732d8ebd58abb 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -424,6 +424,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial mi, rettype, inferred_const, inferred, const_flags, min_world, max_world))) eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v)))) eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields)))) +eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype)))) eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers)))) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 4ef63f3cbaae17..b91733b2bea227 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -24,8 +24,8 @@ function is_improvable(@nospecialize(rtype)) # already at Bottom return rtype !== Union{} 
end - # Could be improved to `Const` or a more precise PartialStruct - return isa(rtype, PartialStruct) + # Could be improved to `Const` or a more precise wrapper + return isa(rtype, PartialStruct) || isa(rtype, InterConditional) end function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState, @@ -1144,7 +1144,10 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), end callinfo = abstract_call(interp, ea, argtypes, sv) sv.stmt_info[sv.currpc] = callinfo.info - t = callinfo.rt + rt = callinfo.rt + t = isa(rt, InterConditional) ? + transform_from_interconditional(rt, ea) : + rt elseif e.head === :new t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1] if isconcretetype(t) && !t.mutable @@ -1255,6 +1258,19 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), return t end +# try to convert interprocedural-conditional constraint from callee into constraints for +# the current frame +function transform_from_interconditional(rt::InterConditional, ea::Vector{Any}) + i = rt.slot + if checkbounds(Bool, ea, i) + e = @inbounds ea[i] + if isa(e, Slot) + return Conditional(e, rt.vtype, rt.elsetype) + end + end + return widenconditional(rt) +end + function abstract_eval_global(M::Module, s::Symbol) if isdefined(M,s) && isconst(M,s) return Const(getfield(M,s)) @@ -1338,10 +1354,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end elseif isa(stmt, ReturnNode) pc´ = n + 1 - rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame)) - if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) + rt = abstract_eval_value(interp, stmt.val, s[pc], frame) + if !isa(rt, Const) && + !isa(rt, Type) && + !isa(rt, PartialStruct) && + !isa(rt, Conditional) # only propagate information we know we can store # and is valid inter-procedurally + # some of them will further be 
transformed by `bestguess_to_interprocedural` + # in `finish(::InferenceState, ::AbstractInterpreter)` rt = widenconst(rt) end if tchanged(rt, frame.bestguess) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index ccf65fafc737fd..573e25f3203bcd 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -1649,6 +1649,9 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s return Const(Union{}) end rt = abstract_call(interp, nothing, argtypes_vec, sv, -1).rt + if isa(rt, InterConditional) + rt = widenconditional(rt) + end if isa(rt, Const) # output was computed to be constant return Const(typeof(rt.val)) diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 1162d05721944c..2cb720e576e992 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -286,20 +286,24 @@ end function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::Any), valid_worlds::WorldRange) local const_flags::Int32 + res = result.result if inferred_result isa Const # use constant calling convention rettype_const = (result.src::Const).val const_flags = 0x3 inferred_result = nothing else - if isa(result.result, Const) - rettype_const = (result.result::Const).val + if isa(res, Const) + rettype_const = res.val const_flags = 0x2 - elseif isconstType(result.result) - rettype_const = result.result.parameters[1] + elseif isconstType(res) + rettype_const = res.parameters[1] const_flags = 0x2 - elseif isa(result.result, PartialStruct) - rettype_const = (result.result::PartialStruct).fields + elseif isa(res, PartialStruct) + rettype_const = res.fields + const_flags = 0x2 + elseif isa(res, InterConditional) + rettype_const = res const_flags = 0x2 else rettype_const = nothing @@ -307,7 +311,7 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::An end end return CodeInstance(result.linfo, - widenconst(result.result), rettype_const, inferred_result, + widenconst(res), 
rettype_const, inferred_result, const_flags, first(valid_worlds), last(valid_worlds)) end @@ -412,10 +416,22 @@ function finish(me::InferenceState, interp::AbstractInterpreter) me.result.src = OptimizationState(me, OptimizationParams(interp), interp) end me.result.valid_worlds = me.valid_worlds - me.result.result = me.bestguess + me.result.result = me.bestguess = + bestguess_to_interprocedural(me.bestguess, length(me.result.argtypes)) nothing end +bestguess_to_interprocedural(@nospecialize(bestguess), _) = bestguess +function bestguess_to_interprocedural(bestguess::Conditional, nargs::Int) + # keep `Conditional` return type only when it constrains any of call argument + i = slot_id(bestguess.var) + if 1 < i ≤ nargs + return InterConditional(i, bestguess.vtype, bestguess.elsetype) + else + return widenconditional(bestguess) + end +end + function finish(src::CodeInfo, interp::AbstractInterpreter) # convert all type information into the form consumed by the cache for inlining and code-generation widen_all_consts!(src) @@ -724,14 +740,18 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize code = get(code_cache(interp), mi, nothing) if code isa CodeInstance # return existing rettype if the code is already inferred update_valid_age!(caller, WorldRange(min_world(code), max_world(code))) + rettype = code.rettype if isdefined(code, :rettype_const) - if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype) - return PartialStruct(code.rettype, code.rettype_const), mi + rettype_const = code.rettype_const + if isa(rettype_const, InterConditional) + return rettype_const, mi + elseif isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype) + return PartialStruct(rettype, rettype_const), mi else - return Const(code.rettype_const), mi + return Const(rettype_const), mi end else - return code.rettype, mi + return rettype, mi end end if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 @@ -802,7 +822,8 @@ function 
typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance) if invoke_api(code) == 2 i == 2 && ccall(:jl_typeinf_end, Cvoid, ()) tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ()) - tree.code = Any[ ReturnNode(quoted(code.rettype_const)) ] + rettype_const = code.rettype_const + tree.code = Any[ ReturnNode(quoted(rettype_const)) ] nargs = Int(method.nargs) tree.slotnames = ccall(:jl_uncompress_argnames, Vector{Symbol}, (Any,), method.slot_syms) tree.slotflags = fill(0x00, nargs) @@ -814,7 +835,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance) tree.pure = true tree.inlineable = true tree.parent = mi - tree.rettype = Core.Typeof(code.rettype_const) + tree.rettype = Core.Typeof(widenconditional(rettype_const)) tree.min_world = code.min_world tree.max_world = code.max_world return tree diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 1ae1f437a6e71e..9e99aa9ffeeade 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -4,7 +4,7 @@ # structs/constants # ##################### -# N.B.: Const/PartialStruct are defined in Core, to allow them to be used +# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used # inside the global code cache. # # # The type of a value might be constant @@ -18,7 +18,6 @@ # end import Core: Const, PartialStruct - # The type of this value might be Bool. # However, to enable a limited amount of back-propagagation, # we also keep some information about how this Bool value was created. 
@@ -45,6 +44,15 @@ struct Conditional end end +# # similar to `Conditional`, but conveys inter-procedural constrains imposed on call arguments +# struct InterConditional +# slot::Int +# vtype +# elsetype +# end +import Core: InterConditional +const AnyConditional = Union{Conditional,InterConditional} + struct PartialTypeVar tv::TypeVar # N.B.: Currently unused, but would allow turning something back @@ -90,11 +98,8 @@ const CompilerTypes = Union{MaybeUndef, Const, Conditional, NotFound, PartialStr # lattice logic # ################# -function issubconditional(a::Conditional, b::Conditional) - avar = a.var - bvar = b.var - if (isa(avar, Slot) && isa(bvar, Slot) && slot_id(avar) === slot_id(bvar)) || - (isa(avar, SSAValue) && isa(bvar, SSAValue) && avar === bvar) +function issubconditional(a::AnyConditional, b::AnyConditional) + if is_same_conditionals(a, b) if a.vtype ⊑ b.vtype if a.elsetype ⊑ b.elsetype return true @@ -103,9 +108,13 @@ function issubconditional(a::Conditional, b::Conditional) end return false end +is_same_conditionals(a::Conditional, b::Conditional) = slot_id(a.var) === slot_id(b.var) +is_same_conditionals(a::Conditional, b::InterConditional) = slot_id(a.var) === b.slot +is_same_conditionals(a::InterConditional, b::Conditional) = is_same_conditionals(b, a) +is_same_conditionals(a::InterConditional, b::InterConditional) = a.slot === b.slot maybe_extract_const_bool(c::Const) = isa(c.val, Bool) ? 
c.val : nothing -function maybe_extract_const_bool(c::Conditional) +function maybe_extract_const_bool(c::AnyConditional) (c.vtype === Bottom && !(c.elsetype === Bottom)) && return false (c.elsetype === Bottom && !(c.vtype === Bottom)) && return true nothing @@ -122,14 +131,14 @@ function ⊑(@nospecialize(a), @nospecialize(b)) (a === Any || b === NOT_FOUND) && return false a === Union{} && return true b === Union{} && return false - if isa(a, Conditional) - if isa(b, Conditional) + if isa(a, AnyConditional) + if isa(b, AnyConditional) return issubconditional(a, b) elseif isa(b, Const) && isa(b.val, Bool) return maybe_extract_const_bool(a) === b.val end a = Bool - elseif isa(b, Conditional) + elseif isa(b, AnyConditional) return false end if isa(a, PartialStruct) @@ -204,7 +213,7 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b)) return a ⊑ b && b ⊑ a end -widenconst(c::Conditional) = Bool +widenconst(c::AnyConditional) = Bool function widenconst(c::Const) if isa(c.val, Type) if isvarargtype(c.val) @@ -237,7 +246,7 @@ end @inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n, o))) widenconditional(@nospecialize typ) = typ -function widenconditional(typ::Conditional) +function widenconditional(typ::AnyConditional) if typ.vtype === Union{} return Const(false) elseif typ.elsetype === Union{} diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index aa1695569d061f..50ed0024afa99a 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -327,6 +327,35 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) end return Bool end + # type-lattice for InterConditional wrapper, InterConditional won't be merged with Conditional + if isa(typea, InterConditional) && isa(typeb, Const) + if typeb.val === true + typeb = InterConditional(typea.slot, Any, Union{}) + elseif typeb.val === false + typeb = InterConditional(typea.slot, Union{}, Any) + end + end + 
if isa(typeb, InterConditional) && isa(typea, Const) + if typea.val === true + typea = InterConditional(typeb.slot, Any, Union{}) + elseif typea.val === false + typea = InterConditional(typeb.slot, Union{}, Any) + end + end + if isa(typea, InterConditional) && isa(typeb, InterConditional) + if typea.slot === typeb.slot + vtype = tmerge(typea.vtype, typeb.vtype) + elsetype = tmerge(typea.elsetype, typeb.elsetype) + if vtype != elsetype + return InterConditional(typea.slot, vtype, elsetype) + end + end + val = maybe_extract_const_bool(typea) + if val isa Bool && val === maybe_extract_const_bool(typeb) + return Const(val) + end + return Bool + end if (isa(typea, PartialStruct) || isa(typea, Const)) && (isa(typeb, PartialStruct) || isa(typeb, Const)) && widenconst(typea) === widenconst(typeb) diff --git a/base/essentials.jl b/base/essentials.jl index bbba482b84974d..545f1eff8bf64e 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -813,8 +813,7 @@ const missing = Missing() Indicate whether `x` is [`missing`](@ref). """ -ismissing(::Any) = false -ismissing(::Missing) = true +ismissing(x) = x === missing function popfirst! end diff --git a/base/some.jl b/base/some.jl index 041c3359e04333..ad106d3264cf0f 100644 --- a/base/some.jl +++ b/base/some.jl @@ -63,8 +63,7 @@ Return `true` if `x === nothing`, and return `false` if not. !!! compat "Julia 1.1" This function requires at least Julia 1.1. 
""" -isnothing(::Any) = false -isnothing(::Nothing) = true +isnothing(x) = x === nothing """ diff --git a/src/builtins.c b/src/builtins.c index 96637080af5376..b20e94a7214b8c 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1622,6 +1622,7 @@ void jl_init_primitives(void) JL_GC_DISABLED add_builtin("Argument", (jl_value_t*)jl_argument_type); add_builtin("Const", (jl_value_t*)jl_const_type); add_builtin("PartialStruct", (jl_value_t*)jl_partial_struct_type); + add_builtin("InterConditional", (jl_value_t*)jl_interconditional_type); add_builtin("MethodMatch", (jl_value_t*)jl_method_match_type); add_builtin("IntrinsicFunction", (jl_value_t*)jl_intrinsic_type); add_builtin("Function", (jl_value_t*)jl_function_type); diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index ee09992c6a584e..58bb6a4dc85d53 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -70,6 +70,7 @@ XX(jl_nothing_type) \ XX(jl_number_type) \ XX(jl_partial_struct_type) \ + XX(jl_interconditional_type) \ XX(jl_phicnode_type) \ XX(jl_phinode_type) \ XX(jl_pinode_type) \ diff --git a/src/jltypes.c b/src/jltypes.c index 0166430671bb12..943d1b38642139 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2302,6 +2302,10 @@ void jl_init_types(void) JL_GC_DISABLED jl_perm_symsvec(2, "typ", "fields"), jl_svec2(jl_any_type, jl_array_any_type), 0, 0, 2); + jl_interconditional_type = jl_new_datatype(jl_symbol("InterConditional"), core, jl_any_type, jl_emptysvec, + jl_perm_symsvec(3, "slot", "vtype", "elsetype"), + jl_svec(3, jl_long_type, jl_any_type, jl_any_type), 0, 0, 3); + jl_method_match_type = jl_new_datatype(jl_symbol("MethodMatch"), core, jl_any_type, jl_emptysvec, jl_perm_symsvec(4, "spec_types", "sparams", "method", "fully_covers"), jl_svec(4, jl_type_type, jl_simplevector_type, jl_method_type, jl_bool_type), 0, 0, 4); diff --git a/src/julia.h b/src/julia.h index d97162c8f8aa3f..6591bcde6ce073 100644 --- a/src/julia.h +++ b/src/julia.h @@ -619,6 +619,7 @@ extern 
JL_DLLIMPORT jl_datatype_t *jl_typedslot_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_argument_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_const_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_partial_struct_type JL_GLOBALLY_ROOTED; +extern JL_DLLIMPORT jl_datatype_t *jl_interconditional_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_method_match_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_simplevector_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_typename_t *jl_tuple_typename JL_GLOBALLY_ROOTED; diff --git a/src/staticdata.c b/src/staticdata.c index 1b827bd12af177..798a8a79f35025 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -68,6 +68,7 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_returnnode_type); INSERT_TAG(jl_const_type); INSERT_TAG(jl_partial_struct_type); + INSERT_TAG(jl_interconditional_type); INSERT_TAG(jl_method_match_type); INSERT_TAG(jl_pinode_type); INSERT_TAG(jl_phinode_type); diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 304a704856ebb8..db665d3d9eae70 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1719,6 +1719,41 @@ for expr25261 in opt25261[i:end] end @test foundslot +@testset "interprocedural conditional constraint propagation" begin + isaint(a) = isa(a, Int) + @test Base.return_types((Any,)) do a + isaint(a) && return a # a::Int + return 0 + end == [Int] + eqnothing(a) = a === nothing + @test Base.return_types((Union{Nothing,Int},)) do a + eqnothing(a) && return 0 + return a # a::Int + end == [Int] + + # tests with base functions + @test Base.return_types((Any,)) do a + Base.Fix2(isa, Int)(a) && return sin(a) # a::Float64 + return 0.0 + end == [Float64] + @test Base.return_types((Union{Nothing,Int},)) do a + isnothing(a) && return 0 + return a # a::Int + end == [Int] + + # FIXME: we can't propagate conditional constraints interprocedurally when there're + # multiple possible 
conditions within the callee + ispositive(a) = isa(a, Int) && a > 0 + @test_broken Base.return_types((Any,)) do a + ispositive(a) && return a # a::Int, ideally + return 0 + end == [Int] + @test_broken Base.return_types((Any,)) do x + Meta.isexpr(x, :call) && return x # x::Expr, ideally + return nothing + end == [Nothing,Expr] +end + function f25579(g) h = g[] t = (h === nothing)