Skip to content

Commit

Permalink
inference: inter-procedural conditional constraint back-propagation
Browse files Browse the repository at this point in the history
This PR propagates `Conditional`s inter-procedurally when a
`Conditional` at return site imposes a constraint on the call arguments.
When inference exits local frame and the return type is annotated as
`Conditional`, it will be converted into `InterConditional` object,
which is implemented in `Core` and can be directly put into the global
cache. Finally after going back to caller frame, `InterConditional` will
be re-converted into `Conditional` in the context of the caller frame.

So now some simple "is-wrapper" functions will propagate its constraint
as expected, e.g.:
```julia
isaint(a) = isa(a, Int)
@test Base.return_types((Any,)) do a
    isaint(a) && return a # a::Int
    return 0
end == [Int]
```

This PR also tweaks `isnothing` and `ismissing` so that there is no
longer any inferrability penalties to use them instead of
`x === nothing` or `x === missing` e.g.:
```julia
@test Base.return_types((Union{Nothing,Int},)) do a
    isnothing(a) && return 0
    return a # a::Int
end == [Int]
```
(and now we don't need something like JuliaLang#38636)

There're certain limitations around. One of the biggest ones would be
that it can propagate constrains only on a single argument, and it fails
to back-propagate constrains when there're multiple conditions on
different slots, e.g. `Meta.isexpr` can't still propagate its type
constraint to the caller:
```julia
@test_broken Base.return_types((Any,)) do x
    Meta.isexpr(x, :call) && return x # still x::Any but ideally x::Expr
    return nothing
end == [Nothing,Expr]
```
(and because of this reason, this PR can't close JuliaLang#37342)
  • Loading branch information
aviatesk committed Dec 16, 2020
1 parent f184f05 commit 52e547f
Show file tree
Hide file tree
Showing 14 changed files with 150 additions and 24 deletions.
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype))))
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))

Expand Down
25 changes: 22 additions & 3 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,10 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
end
callinfo = abstract_call(interp, ea, argtypes, sv)
sv.stmt_info[sv.currpc] = callinfo.info
t = callinfo.rt
rt = callinfo.rt
t = isa(rt, InterConditional) ?
transform_from_interconditional(rt, ea) :
rt
elseif e.head === :new
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
if isconcretetype(t) && !t.mutable
Expand Down Expand Up @@ -1255,6 +1258,19 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
return t
end

# try to convert interprocedural-conditional constraint from callee into constraints for
# the current frame
function transform_from_interconditional(rt::InterConditional, ea::Vector{Any})
i = rt.slot
if checkbounds(Bool, ea, i)
e = @inbounds ea[i]
if isa(e, Slot)
return Conditional(e, rt.vtype, rt.elsetype)
end
end
return widenconditional(rt)
end

function abstract_eval_global(M::Module, s::Symbol)
if isdefined(M,s) && isconst(M,s)
return Const(getfield(M,s))
Expand Down Expand Up @@ -1338,8 +1354,11 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
elseif isa(stmt, ReturnNode)
pc´ = n + 1
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
rt = abstract_eval_value(interp, stmt.val, s[pc], frame)
if !isa(rt, Const) &&
!isa(rt, Type) &&
!isa(rt, PartialStruct) &&
!isa(rt, Conditional)
# only propagate information we know we can store
# and is valid inter-procedurally
rt = widenconst(rt)
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1649,6 +1649,9 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
return Const(Union{})
end
rt = abstract_call(interp, nothing, argtypes_vec, sv, -1).rt
if isa(rt, InterConditional)
rt = widenconditional(rt)
end
if isa(rt, Const)
# output was computed to be constant
return Const(typeof(rt.val))
Expand Down
50 changes: 37 additions & 13 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,28 +286,34 @@ end
function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::Any),
valid_worlds::WorldRange)
local const_flags::Int32
res = result.result
rettype = widenconst(res)
if inferred_result isa Const
# use constant calling convention
rettype_const = (result.src::Const).val
const_flags = 0x3
inferred_result = nothing
else
if isa(result.result, Const)
rettype_const = (result.result::Const).val
if isa(res, Const)
rettype_const = res.val
const_flags = 0x2
elseif isconstType(result.result)
rettype_const = result.result.parameters[1]
elseif isconstType(res)
rettype_const = res.parameters[1]
const_flags = 0x2
elseif isa(result.result, PartialStruct)
rettype_const = (result.result::PartialStruct).fields
elseif isa(res, PartialStruct)
rettype_const = res.fields
const_flags = 0x2
else
if isa(res, Conditional)
# TODO: put this into its own field ?
rettype = transform_to_interconditional(res, length(result.argtypes))
end
rettype_const = nothing
const_flags = 0x00
end
end
return CodeInstance(result.linfo,
widenconst(result.result), rettype_const, inferred_result,
rettype, rettype_const, inferred_result,
const_flags, first(valid_worlds), last(valid_worlds))
end

Expand Down Expand Up @@ -724,14 +730,15 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance # return existing rettype if the code is already inferred
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
rettype = code.rettype
if isdefined(code, :rettype_const)
if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype)
return PartialStruct(code.rettype, code.rettype_const), mi
if isa(code.rettype_const, Vector{Any}) && !(isa(rettype, InterConditional) || Vector{Any} <: rettype)
return PartialStruct(rettype, code.rettype_const), mi
else
return Const(code.rettype_const), mi
end
else
return code.rettype, mi
return rettype, mi
end
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0
Expand Down Expand Up @@ -759,15 +766,32 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
end
typeinf(interp, frame)
update_valid_age!(frame, caller)
return frame.bestguess, frame.inferred ? mi : nothing
bestguess = frame.bestguess
if isa(bestguess, Conditional)
bestguess = transform_to_interconditional(bestguess, length(result.argtypes))
end
return bestguess, frame.inferred ? mi : nothing
elseif frame === true
# unresolvable cycle
return Any, nothing
end
# return the current knowledge about this cycle
frame = frame::InferenceState
update_valid_age!(frame, caller)
return frame.bestguess, nothing
bestguess = frame.bestguess
if isa(bestguess, Conditional)
bestguess = transform_to_interconditional(bestguess, length(frame.result.argtypes))
end
return bestguess, nothing
end

function transform_to_interconditional(bestguess::Conditional, nargs::Int)
# keep this conditional only when it constrains a slot within call arguments
if 1 < slot_id(bestguess.var) <= nargs
return InterConditional(slot_id(bestguess.var), bestguess.vtype, bestguess.elsetype)
else
return widenconditional(bestguess)
end
end

#### entry points for inferring a MethodInstance given a type signature ####
Expand Down Expand Up @@ -860,7 +884,7 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
if code isa CodeInstance
# see if this rettype already exists in the cache
i == 2 && ccall(:jl_typeinf_end, Cvoid, ())
return code.rettype
return widenconst(code.rettype)
end
end
frame = InferenceResult(mi)
Expand Down
17 changes: 13 additions & 4 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# structs/constants #
#####################

# N.B.: Const/PartialStruct are defined in Core, to allow them to be used
# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
# inside the global code cache.
#
# # The type of a value might be constant
Expand All @@ -18,7 +18,6 @@
# end
import Core: Const, PartialStruct


# The type of this value might be Bool.
# However, to enable a limited amount of back-propagagation,
# we also keep some information about how this Bool value was created.
Expand All @@ -45,6 +44,15 @@ struct Conditional
end
end

# # similar to `Conditional`, but conveys inter-procedural constrains imposed on call arguments
# struct InterConditional
# slot::Int
# vtype
# elsetype
# end
import Core: InterConditional
const ConditionalWrapper = Union{Conditional,InterConditional}

struct PartialTypeVar
tv::TypeVar
# N.B.: Currently unused, but would allow turning something back
Expand Down Expand Up @@ -105,7 +113,7 @@ function issubconditional(a::Conditional, b::Conditional)
end

maybe_extract_const_bool(c::Const) = isa(c.val, Bool) ? c.val : nothing
function maybe_extract_const_bool(c::Conditional)
function maybe_extract_const_bool(c::ConditionalWrapper)
(c.vtype === Bottom && !(c.elsetype === Bottom)) && return false
(c.elsetype === Bottom && !(c.vtype === Bottom)) && return true
nothing
Expand Down Expand Up @@ -205,6 +213,7 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
end

widenconst(c::Conditional) = Bool
widenconst(c::InterConditional) = Bool
function widenconst(c::Const)
if isa(c.val, Type)
if isvarargtype(c.val)
Expand Down Expand Up @@ -237,7 +246,7 @@ end
@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n, o)))

widenconditional(@nospecialize typ) = typ
function widenconditional(typ::Conditional)
function widenconditional(typ::ConditionalWrapper)
if typ.vtype === Union{}
return Const(false)
elseif typ.elsetype === Union{}
Expand Down
29 changes: 29 additions & 0 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,35 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
end
return Bool
end
# type-lattice for InterConditional wrapper, InterConditional won't be merged with Conditional
if isa(typea, InterConditional) && isa(typeb, Const)
if typeb.val === true
typeb = InterConditional(typea.slot, Any, Union{})
elseif typeb.val === false
typeb = InterConditional(typea.slot, Union{}, Any)
end
end
if isa(typeb, InterConditional) && isa(typea, Const)
if typea.val === true
typea = InterConditional(typeb.slot, Any, Union{})
elseif typea.val === false
typea = InterConditional(typeb.slot, Union{}, Any)
end
end
if isa(typea, InterConditional) && isa(typeb, InterConditional)
if typea.slot === typeb.slot
vtype = tmerge(typea.vtype, typeb.vtype)
elsetype = tmerge(typea.elsetype, typeb.elsetype)
if vtype != elsetype
return InterConditional(typea.slot, vtype, elsetype)
end
end
val = maybe_extract_const_bool(typea)
if val isa Bool && val === maybe_extract_const_bool(typeb)
return Const(val)
end
return Bool
end
if (isa(typea, PartialStruct) || isa(typea, Const)) &&
(isa(typeb, PartialStruct) || isa(typeb, Const)) &&
widenconst(typea) === widenconst(typeb)
Expand Down
3 changes: 1 addition & 2 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -813,8 +813,7 @@ const missing = Missing()
Indicate whether `x` is [`missing`](@ref).
"""
ismissing(::Any) = false
ismissing(::Missing) = true
ismissing(x) = x === missing

function popfirst! end

Expand Down
3 changes: 1 addition & 2 deletions base/some.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ Return `true` if `x === nothing`, and return `false` if not.
!!! compat "Julia 1.1"
This function requires at least Julia 1.1.
"""
isnothing(::Any) = false
isnothing(::Nothing) = true
isnothing(x) = x === nothing


"""
Expand Down
1 change: 1 addition & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("Argument", (jl_value_t*)jl_argument_type);
add_builtin("Const", (jl_value_t*)jl_const_type);
add_builtin("PartialStruct", (jl_value_t*)jl_partial_struct_type);
add_builtin("InterConditional", (jl_value_t*)jl_interconditional_type);
add_builtin("MethodMatch", (jl_value_t*)jl_method_match_type);
add_builtin("IntrinsicFunction", (jl_value_t*)jl_intrinsic_type);
add_builtin("Function", (jl_value_t*)jl_function_type);
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
XX(jl_nothing_type) \
XX(jl_number_type) \
XX(jl_partial_struct_type) \
XX(jl_interconditional_type) \
XX(jl_phicnode_type) \
XX(jl_phinode_type) \
XX(jl_pinode_type) \
Expand Down
4 changes: 4 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2302,6 +2302,10 @@ void jl_init_types(void) JL_GC_DISABLED
jl_perm_symsvec(2, "typ", "fields"),
jl_svec2(jl_any_type, jl_array_any_type), 0, 0, 2);

jl_interconditional_type = jl_new_datatype(jl_symbol("InterConditional"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(3, "slot", "vtype", "elsetype"),
jl_svec(3, jl_long_type, jl_any_type, jl_any_type), 0, 0, 3);

jl_method_match_type = jl_new_datatype(jl_symbol("MethodMatch"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(4, "spec_types", "sparams", "method", "fully_covers"),
jl_svec(4, jl_type_type, jl_simplevector_type, jl_method_type, jl_bool_type), 0, 0, 4);
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_typedslot_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_argument_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_const_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_partial_struct_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_interconditional_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_method_match_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_simplevector_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_typename_t *jl_tuple_typename JL_GLOBALLY_ROOTED;
Expand Down
1 change: 1 addition & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ jl_value_t **const*const get_tags(void) {
INSERT_TAG(jl_returnnode_type);
INSERT_TAG(jl_const_type);
INSERT_TAG(jl_partial_struct_type);
INSERT_TAG(jl_interconditional_type);
INSERT_TAG(jl_method_match_type);
INSERT_TAG(jl_pinode_type);
INSERT_TAG(jl_phinode_type);
Expand Down
35 changes: 35 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,41 @@ for expr25261 in opt25261[i:end]
end
@test foundslot

@testset "interprocedural conditional constraint propagation" begin
isaint(a) = isa(a, Int)
@test Base.return_types((Any,)) do a
isaint(a) && return a # a::Int
return 0
end == [Int]
eqnothing(a) = a === nothing
@test Base.return_types((Union{Nothing,Int},)) do a
eqnothing(a) && return 0
return a # a::Int
end == [Int]

# tests with base functions
@test Base.return_types((Any,)) do a
Base.Fix2(isa, Int)(a) && return sin(a) # a::Float64
return 0.0
end == [Float64]
@test Base.return_types((Union{Nothing,Int},)) do a
isnothing(a) && return 0
return a # a::Int
end == [Int]

# FIXME: we can't propagate conditional constraints interprocedurally when there're
# multiple possible conditions within the callee
ispositive(a) = isa(a, Int) && a > 0
@test_broken Base.return_types((Any,)) do a
ispositive(a) && return a # a::Int, ideally
return 0
end == [Int]
@test_broken Base.return_types((Any,)) do x
Meta.isexpr(x, :call) && return x # x::Expr, ideally
return nothing
end == [Nothing,Expr]
end

function f25579(g)
h = g[]
t = (h === nothing)
Expand Down

0 comments on commit 52e547f

Please sign in to comment.