Skip to content

Commit

Permalink
inference: inter-procedural conditional constraint back-propagation
Browse files Browse the repository at this point in the history
This PR propagates `Conditional`s inter-procedurally when a
`Conditional` at a return site imposes a constraint on the call arguments.
When inference exits the local frame and the return type is annotated as a
`Conditional`, it is converted into an `InterConditional` object,
which is implemented in `Core` and can be put directly into the global
cache. Finally, after going back to the caller frame, the `InterConditional`
is converted back into a `Conditional` in the context of the caller frame.

 ## improvements

So now some simple "is-wrapper" functions will propagate their constraints
as expected, e.g.:
```julia
isaint(a) = isa(a, Int)
@test Base.return_types((Any,)) do a
    isaint(a) && return a # a::Int
    return 0
end == Any[Int]

isaint2(::Any) = false
isaint2(::Int) = true
@test Base.return_types((Any,)) do a
    isaint2(a) && return a # a::Int
    return 0
end == Any[Int]

function isa_int_or_float64(a)
    isa(a, Int) && return true
    isa(a, Float64) && return true
    return false
end
@test Base.return_types((Any,)) do a
    isa_int_or_float64(a) && return a # a::Union{Float64,Int}
    0
end == Any[Union{Float64,Int}]
```

(and now we don't need something like JuliaLang#38636)

 ## benchmarks

A compile time comparison:
> on the current master (82d79ce)
```
Sysimage built. Summary:
Total ───────  55.295376 seconds
Base: ───────  23.359226 seconds 42.2444%
Stdlibs: ────  31.934773 seconds 57.7531%
    JULIA usr/lib/julia/sys-o.a
Generating REPL precompile statements... 29/29
Executing precompile statements... 1283/1283
Precompilation complete. Summary:
Total ───────  91.129162 seconds
Generation ──  68.800937 seconds 75.4983%
Execution ───  22.328225 seconds 24.5017%
    LINK usr/lib/julia/sys.dylib
```

> on this PR (37e279b)
```
Sysimage built. Summary:
Total ───────  51.694730 seconds
Base: ───────  21.943914 seconds 42.449%
Stdlibs: ────  29.748987 seconds 57.5474%
    JULIA usr/lib/julia/sys-o.a
Generating REPL precompile statements... 29/29
Executing precompile statements... 1357/1357
Precompilation complete. Summary:
Total ───────  88.956226 seconds
Generation ──  67.077710 seconds 75.4053%
Execution ───  21.878515 seconds 24.5947%
    LINK usr/lib/julia/sys.dylib
```

Here is a sample code that benefits from this PR:
```julia
function summer(ary)
    r = 0
    for a in ary
        if ispositive(a)
            r += a
        end
    end
    r
end

ispositive(a) = isa(a, Int) && a > 0

ary = Any[]
for _ in 1:100_000
    if rand(Bool)
        push!(ary, rand(-100:100))
    elseif rand(Bool)
        push!(ary, rand('a':'z'))
    else
        push!(ary, nothing)
    end
end

using BenchmarkTools
@btime summer($(ary))
```

> on the current master (82d79ce)
```
❯ julia summer.jl
  1.214 ms (24923 allocations: 389.42 KiB)
```

> on this PR (37e279b)
```
❯ julia summer.jl
  421.223 μs (0 allocations: 0 bytes)
```

 ## caveats

Within the `Conditional`/`InterConditional` framework, only a single
constraint can be back-propagated inter-procedurally. This PR implements
a naive heuristic to "pick up" a constraint to be propagated when a
return type is a boolean. The heuristic may fail to select an
"interesting" constraint in some cases. For example, we may expect
`::Expr` constraint to be imposed on the first argument of
`Meta.isexpr`, but the current heuristic ends up picking up a constraint
on the second argument (i.e. `ex.head === head`).
```julia
isexpr(@nospecialize(ex), head::Symbol) = isa(ex, Expr) && ex.head === head

@test_broken Base.return_types((Any,)) do x
    Meta.isexpr(x, :call) && return x # x::Expr, ideally
    return nothing
end == Any[Union{Nothing,Expr}]
```

I think we can get rid of this limitation by extending `Conditional` and
`InterConditional`
so that they can convey multiple constraints, but I'd like to leave this
as future work.

---

- closes JuliaLang#38636
- closes JuliaLang#37342
  • Loading branch information
aviatesk committed Feb 6, 2021
1 parent 55baf8a commit f06ab09
Show file tree
Hide file tree
Showing 11 changed files with 234 additions and 36 deletions.
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype))))
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))

Expand Down
101 changes: 85 additions & 16 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ function is_improvable(@nospecialize(rtype))
# already at Bottom
return rtype !== Union{}
end
# Could be improved to `Const` or a more precise PartialStruct
return isa(rtype, PartialStruct)
# Could be improved to `Const` or a more precise wrapper
return isa(rtype, PartialStruct) || isa(rtype, InterConditional)
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
Expand Down Expand Up @@ -191,6 +191,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end
end

@assert !(rettype isa Conditional) "invalid lattice element returned from inter-procedural context"
#print("=> ", rettype, "\n")
if rettype isa LimitedAccuracy
union!(sv.pclimitations, rettype.causes)
Expand Down Expand Up @@ -1062,9 +1064,29 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
end
return abstract_call_gf_by_type(interp, nothing, argtypes, argtypes_to_type(argtypes), sv, max_methods)
callinfo = abstract_call_gf_by_type(interp, nothing, argtypes, argtypes_to_type(argtypes), sv, max_methods)
return callinfo_from_interprocedural(callinfo, fargs)
end
callinfo = abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
return callinfo_from_interprocedural(callinfo, fargs)
end

# Translate the result of an inter-procedural call back into the caller frame's lattice.
#
# - `callinfo`: the `CallMeta` obtained from analyzing the call
# - `ea`: the argument expressions of the call, or `nothing` when they are unavailable
#
# If the return type is an `InterConditional`, it is re-expressed as a `Conditional` on the
# caller's slot found at `ea[rt.slot]`. When that argument is not a `Slot` (or `ea` is
# `nothing`) there is nothing to attach the constraint to, so the return type is widened
# with `widenconditional` instead. Any other return type is passed through unchanged.
function callinfo_from_interprocedural(callinfo::CallMeta, ea::Union{Nothing,Vector{Any}})
    rt = callinfo.rt
    if isa(rt, InterConditional)
        if ea !== nothing
            # convert inter-procedural conditional constraint from callee into the constraint
            # on slots of the current frame; `InterConditional` only comes from a "valid"
            # `abstract_call` as such its slot should always be within the bound of this
            # call arguments `ea`
            e = ea[rt.slot]
            if isa(e, Slot)
                return CallMeta(Conditional(e, rt.vtype, rt.elsetype), callinfo.info)
            end
        end
        # the constraint can't be tied to a slot of this frame
        return CallMeta(widenconditional(rt), callinfo.info)
    end
    # not an `InterConditional`: nothing to convert
    return callinfo
end

function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
Expand Down Expand Up @@ -1319,6 +1341,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
W = frame.ip
s = frame.stmt_types
n = frame.nstmts
nargs = frame.nargs
def = frame.linfo.def
isva = isa(def, Method) && def.isva
while frame.pc´´ <= n
# make progress on the active ip set
local pc::Int = frame.pc´´ # current program-counter
Expand Down Expand Up @@ -1369,12 +1394,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.handler_at[l] = frame.cur_hand
changes_else = changes
if isa(condt, Conditional)
if condt.elsetype !== Any && condt.elsetype !== changes[slot_id(condt.var)]
changes_else = StateUpdate(condt.var, VarState(condt.elsetype, false), changes_else)
end
if condt.vtype !== Any && condt.vtype !== changes[slot_id(condt.var)]
changes = StateUpdate(condt.var, VarState(condt.vtype, false), changes)
end
changes_else = conditional_changes(changes_else, condt.elsetype, condt.var)
changes = conditional_changes(changes, condt.vtype, condt.var)
end
newstate_else = stupdate!(s[l], changes_else)
if newstate_else !== false
Expand All @@ -1388,10 +1409,37 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
elseif isa(stmt, ReturnNode)
pc´ = n + 1
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
# only propagate information we know we can store
# and is valid inter-procedurally
bestguess = frame.bestguess
rt = abstract_eval_value(interp, stmt.val, s[pc], frame)
if isva
# give up inter-procedural constraint back-propagation from vararg methods
# because types of same slot may differ between callee and caller
rt = widenconditional(rt)
else
if isa(rt, Conditional) && !(1 ≤ slot_id(rt.var) ≤ nargs)
# discard this `Conditional` imposed on non-call arguments,
# since it's not interesting in inter-procedural context;
# we may give constraints on other call argument
rt = widenconditional(rt)
end
if !isa(rt, Conditional) && rt ⊑ Bool
if isa(bestguess, InterConditional)
# if the bestguess so far is already `InterConditional`, try to convert
# this `rt` into `Conditional` on the same slot to avoid overapproximation
# due to conflict of different slots
rt = boolean_rt_to_conditional(rt, changes, bestguess.slot)
elseif nargs ≥ 1
# pick up the first "interesting" slot, convert `rt` to its `Conditional`
# TODO: this is very naive heuristic, ideally we want `Conditional`
# and `InterConditional` to convey constraints on multiple slots
rt = boolean_rt_to_conditional(rt, changes, nargs > 1 ? 2 : 1)
end
end
end
# only propagate information we know we can store and is valid inter-procedurally
if isa(rt, Conditional)
rt = InterConditional(slot_id(rt.var), rt.vtype, rt.elsetype)
elseif !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
rt = widenconst(rt)
end
# copy limitations to return value
Expand All @@ -1402,9 +1450,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if !isempty(frame.limitations)
rt = LimitedAccuracy(rt, copy(frame.limitations))
end
if tchanged(rt, frame.bestguess)
if tchanged(rt, bestguess)
# new (wider) return type for frame
frame.bestguess = tmerge(frame.bestguess, rt)
frame.bestguess = tmerge(bestguess, rt)
for (caller, caller_pc) in frame.cycle_backedges
# notify backedges of updated type information
typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before
Expand Down Expand Up @@ -1510,6 +1558,27 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
nothing
end

# Compute the state change implied by a conditional constraint on `var`.
#
# - `changes`: the variable table holding the currently recorded state of each slot
# - `typ`: the type the constraint imposes on `var`
# - `var`: the slot the constraint is about
#
# Returns a `StateUpdate` recording `typ` for `var` when `typ` is at or below (`⊑`) the
# type currently recorded for that slot; otherwise returns `changes` untouched, so a
# constraint never replaces the recorded type with one that is not `⊑` it.
function conditional_changes(changes::VarTable, @nospecialize(typ), var::Slot)
    if typ ⊑ (changes[slot_id(var)]::VarState).typ
        return StateUpdate(var, VarState(typ, false), changes)
    end
    return changes
end

# Re-express a boolean return type `rt` as a `Conditional` on the slot numbered `slot_id`.
#
# The slot's current type is read from `state`:
# - `rt === Bool`: both branches keep the slot's current type
# - `rt` is `Const(true)`: the else-branch type becomes `Bottom`
# - `rt` is `Const(false)`: the then-branch type becomes `Bottom`
# Anything else is returned as is.
function boolean_rt_to_conditional(@nospecialize(rt), state::VarTable, slot_id::Int)
    # widen first so that a conditional is never nested inside another conditional
    slottype = widenconditional((state[slot_id]::VarState).typ)
    if rt === Bool
        return Conditional(SlotNumber(slot_id), slottype, slottype)
    elseif isa(rt, Const)
        val = rt.val
        val === true && return Conditional(SlotNumber(slot_id), slottype, Bottom)
        val === false && return Conditional(SlotNumber(slot_id), Bottom, slottype)
    end
    return rt
end

# make as much progress on `frame` as possible (by handling cycles)
function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
typeinf_local(interp, frame)
Expand Down
15 changes: 11 additions & 4 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::An
elseif isa(result_type, PartialStruct)
rettype_const = result_type.fields
const_flags = 0x2
elseif isa(result_type, InterConditional)
rettype_const = result_type
const_flags = 0x2
else
rettype_const = nothing
const_flags = 0x00
Expand Down Expand Up @@ -770,14 +773,18 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance # return existing rettype if the code is already inferred
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
rettype = code.rettype
if isdefined(code, :rettype_const)
if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype)
return PartialStruct(code.rettype, code.rettype_const), mi
rettype_const = code.rettype_const
if isa(rettype_const, InterConditional)
return rettype_const, mi
elseif isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
return PartialStruct(rettype, rettype_const), mi
else
return Const(code.rettype_const), mi
return Const(rettype_const), mi
end
else
return code.rettype, mi
return rettype, mi
end
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0
Expand Down
40 changes: 26 additions & 14 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# structs/constants #
#####################

# N.B.: Const/PartialStruct are defined in Core, to allow them to be used
# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
# inside the global code cache.
#
# # The type of a value might be constant
Expand All @@ -18,7 +18,6 @@
# end
import Core: Const, PartialStruct


# The type of this value might be Bool.
# However, to enable a limited amount of back-propagation,
# we also keep some information about how this Bool value was created.
Expand All @@ -45,6 +44,18 @@ struct Conditional
end
end

# # Similar to `Conditional`, but conveys inter-procedural constraints imposed on call arguments.
# # This is separate from `Conditional` to catch logic errors: the lattice element name is InterConditional
# # while processing a call, then Conditional everywhere else. Thus InterConditional does not appear in
# # CompilerTypes—these types' usages are disjoint—though we define the lattice for InterConditional.
# struct InterConditional
# slot::Int
# vtype
# elsetype
# end
import Core: InterConditional
const AnyConditional = Union{Conditional,InterConditional}

struct PartialTypeVar
tv::TypeVar
# N.B.: Currently unused, but would allow turning something back
Expand Down Expand Up @@ -101,11 +112,10 @@ const CompilerTypes = Union{MaybeUndef, Const, Conditional, NotFound, PartialStr
# lattice logic #
#################

function issubconditional(a::Conditional, b::Conditional)
avar = a.var
bvar = b.var
if (isa(avar, Slot) && isa(bvar, Slot) && slot_id(avar) === slot_id(bvar)) ||
(isa(avar, SSAValue) && isa(bvar, SSAValue) && avar === bvar)
# `Conditional` and `InterConditional` are valid in opposite contexts
# (i.e. local inference and inter-procedural call), as such they will never be compared
function issubconditional(a::C, b::C) where {C<:AnyConditional}
if is_same_conditionals(a, b)
if a.vtype ⊑ b.vtype
if a.elsetype ⊑ b.elsetype
return true
Expand All @@ -114,9 +124,11 @@ function issubconditional(a::Conditional, b::Conditional)
end
return false
end
# Two `Conditional`s are about the same variable when the ids of their slots are identical.
function is_same_conditionals(a::Conditional, b::Conditional)
    return slot_id(a.var) === slot_id(b.var)
end
# Two `InterConditional`s are about the same variable when their `slot` fields are identical.
function is_same_conditionals(a::InterConditional, b::InterConditional)
    return a.slot === b.slot
end

maybe_extract_const_bool(c::Const) = isa(c.val, Bool) ? c.val : nothing
function maybe_extract_const_bool(c::Conditional)
maybe_extract_const_bool(c::Const) = (val = c.val; isa(val, Bool)) ? val : nothing
function maybe_extract_const_bool(c::AnyConditional)
(c.vtype === Bottom && !(c.elsetype === Bottom)) && return false
(c.elsetype === Bottom && !(c.vtype === Bottom)) && return true
nothing
Expand Down Expand Up @@ -145,14 +157,14 @@ function ⊑(@nospecialize(a), @nospecialize(b))
b === Union{} && return false
@assert !isa(a, TypeVar) "invalid lattice item"
@assert !isa(b, TypeVar) "invalid lattice item"
if isa(a, Conditional)
if isa(b, Conditional)
if isa(a, AnyConditional)
if isa(b, AnyConditional)
return issubconditional(a, b)
elseif isa(b, Const) && isa(b.val, Bool)
return maybe_extract_const_bool(a) === b.val
end
a = Bool
elseif isa(b, Conditional)
elseif isa(b, AnyConditional)
return false
end
if isa(a, PartialStruct)
Expand Down Expand Up @@ -226,7 +238,7 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
return a ⊑ b && b ⊑ a
end

widenconst(c::Conditional) = Bool
widenconst(c::AnyConditional) = Bool
function widenconst(c::Const)
if isa(c.val, Type)
if isvarargtype(c.val)
Expand Down Expand Up @@ -260,7 +272,7 @@ end
@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n, o)))

widenconditional(@nospecialize typ) = typ
function widenconditional(typ::Conditional)
function widenconditional(typ::AnyConditional)
if typ.vtype === Union{}
return Const(false)
elseif typ.elsetype === Union{}
Expand Down
32 changes: 31 additions & 1 deletion base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
end
end
if isa(typea, Conditional) && isa(typeb, Conditional)
if typea.var === typeb.var
if is_same_conditionals(typea, typeb)
vtype = tmerge(typea.vtype, typeb.vtype)
elsetype = tmerge(typea.elsetype, typeb.elsetype)
if vtype != elsetype
Expand All @@ -347,6 +347,36 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
end
return Bool
end
# type-lattice for InterConditional wrapper, InterConditional will never be merged with Conditional
if isa(typea, InterConditional) && isa(typeb, Const)
if typeb.val === true
typeb = InterConditional(typea.slot, Any, Union{})
elseif typeb.val === false
typeb = InterConditional(typea.slot, Union{}, Any)
end
end
if isa(typeb, InterConditional) && isa(typea, Const)
if typea.val === true
typea = InterConditional(typeb.slot, Any, Union{})
elseif typea.val === false
typea = InterConditional(typeb.slot, Union{}, Any)
end
end
if isa(typea, InterConditional) && isa(typeb, InterConditional)
if is_same_conditionals(typea, typeb)
vtype = tmerge(typea.vtype, typeb.vtype)
elsetype = tmerge(typea.elsetype, typeb.elsetype)
if vtype != elsetype
return InterConditional(typea.slot, vtype, elsetype)
end
end
val = maybe_extract_const_bool(typea)
if val isa Bool && val === maybe_extract_const_bool(typeb)
return Const(val)
end
return Bool
end
# type-lattice for Const and PartialStruct wrappers
if (isa(typea, PartialStruct) || isa(typea, Const)) &&
(isa(typeb, PartialStruct) || isa(typeb, Const)) &&
widenconst(typea) === widenconst(typeb)
Expand Down
1 change: 1 addition & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("Argument", (jl_value_t*)jl_argument_type);
add_builtin("Const", (jl_value_t*)jl_const_type);
add_builtin("PartialStruct", (jl_value_t*)jl_partial_struct_type);
add_builtin("InterConditional", (jl_value_t*)jl_interconditional_type);
add_builtin("MethodMatch", (jl_value_t*)jl_method_match_type);
add_builtin("IntrinsicFunction", (jl_value_t*)jl_intrinsic_type);
add_builtin("Function", (jl_value_t*)jl_function_type);
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
XX(jl_nothing_type) \
XX(jl_number_type) \
XX(jl_partial_struct_type) \
XX(jl_interconditional_type) \
XX(jl_phicnode_type) \
XX(jl_phinode_type) \
XX(jl_pinode_type) \
Expand Down
4 changes: 4 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2308,6 +2308,10 @@ void jl_init_types(void) JL_GC_DISABLED
jl_perm_symsvec(2, "typ", "fields"),
jl_svec2(jl_any_type, jl_array_any_type), 0, 0, 2);

jl_interconditional_type = jl_new_datatype(jl_symbol("InterConditional"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(3, "slot", "vtype", "elsetype"),
jl_svec(3, jl_long_type, jl_any_type, jl_any_type), 0, 0, 3);

jl_method_match_type = jl_new_datatype(jl_symbol("MethodMatch"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(4, "spec_types", "sparams", "method", "fully_covers"),
jl_svec(4, jl_type_type, jl_simplevector_type, jl_method_type, jl_bool_type), 0, 0, 4);
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_typedslot_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_argument_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_const_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_partial_struct_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_interconditional_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_method_match_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_simplevector_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_typename_t *jl_tuple_typename JL_GLOBALLY_ROOTED;
Expand Down
1 change: 1 addition & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ jl_value_t **const*const get_tags(void) {
INSERT_TAG(jl_returnnode_type);
INSERT_TAG(jl_const_type);
INSERT_TAG(jl_partial_struct_type);
INSERT_TAG(jl_interconditional_type);
INSERT_TAG(jl_method_match_type);
INSERT_TAG(jl_pinode_type);
INSERT_TAG(jl_phinode_type);
Expand Down
Loading

0 comments on commit f06ab09

Please sign in to comment.