Skip to content

Commit

Permalink
inference: form PartialStruct for extra type information propagation
Browse files Browse the repository at this point in the history
This commit forms `PartialStruct` whenever there is any type-level
refinement available about a field, even if it's not "constant" information.

In Julia "definitions" are allowed to be abstract whereas "usages"
(i.e. callsites) are often concrete. The basic idea is to allow inference
to make more use of such precise callsite type information by encoding it
as `PartialStruct`.

This may increase optimization possibilities of "unidiomatic" Julia code,
which may contain poorly-typed definitions, like this very contrived example:
```julia
struct Problem
    n; s; c; t
end

function main(args...)
    prob = Problem(args...)
    s = 0
    for i in 1:prob.n
        m = mod(i, 3)
        s += m == 0 ? sin(prob.s) : m == 1 ? cos(prob.c) : tan(prob.t)
    end
    return prob, s
end

main(10000, 1, 2, 3)
```

One of the obvious limitation is that this extra type information can be
propagated inter-procedurally only as a const-propagation.
I'm not sure this kind of "just a type-level" refinement can often make
constant-prop' successful (i.e. shape-up a method body and allow it to
be inlined, encoding the extra type information into the generated code),
thus I didn't not modify any part of const-prop' heuristics.

So the improvements from this change is almost for local analysis,
and for very simple inter-procedural calls.
  • Loading branch information
aviatesk committed Oct 29, 2021
1 parent c054dbc commit ab23760
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 29 deletions.
28 changes: 15 additions & 13 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
else
elsetype = tmeet(elsetype, widenconst(new_elsetype))
end
if (slot > 0 || condval !== false) && !(old vtype) # essentially vtype ⋤ old
if (slot > 0 || condval !== false) && vtype old
slot = id
elseif (slot > 0 || condval !== true) && !(old elsetype) # essentially elsetype ⋤ old
elseif (slot > 0 || condval !== true) && elsetype old
slot = id
else # reset: no new useful information for this slot
vtype = elsetype = Any
Expand Down Expand Up @@ -1542,22 +1542,23 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if isconcretetype(t) && !ismutabletype(t)
args = Vector{Any}(undef, length(e.args)-1)
ats = Vector{Any}(undef, length(e.args)-1)
anyconst = false
allconst = true
local anyrefine = false
local allconst = true
for i = 2:length(e.args)
at = widenconditional(abstract_eval_value(interp, e.args[i], vtypes, sv))
if !anyconst
anyconst = has_nontrivial_const_info(at)
if !anyrefine
anyrefine = has_nontrivial_const_info(at) || # constant information
at fieldtype(t, i - 1) # just a type-level information, but more precise than the declared type
end
ats[i-1] = at
if at === Bottom
t = Bottom
allconst = anyconst = false
anyrefine = allconst = false
break
elseif at isa Const
if !(at.val isa fieldtype(t, i - 1))
t = Bottom
allconst = anyconst = false
anyrefine = allconst = false
break
end
args[i-1] = at.val
Expand All @@ -1569,7 +1570,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if t !== Bottom && fieldcount(t) == length(ats)
if allconst
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
elseif anyconst
elseif anyrefine
t = PartialStruct(t, ats)
end
end
Expand Down Expand Up @@ -1741,17 +1742,18 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s
isa(rt, Type) && return rt
if isa(rt, PartialStruct)
fields = copy(rt.fields)
haveconst = false
local anyrefine = false
for i in 1:length(fields)
a = fields[i]
a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes)
if !haveconst && has_const_info(a)
if !anyrefine
# TODO: consider adding && const_prop_profitable(a) here?
haveconst = true
anyrefine = has_const_info(a) ||
a fieldtype(rt.typ, i)
end
fields[i] = a
end
haveconst && return PartialStruct(rt.typ, fields)
anyrefine && return PartialStruct(rt.typ, fields)
end
if isa(rt, PartialOpaque)
return rt # XXX: this case was missed in #39512
Expand Down
1 change: 1 addition & 0 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ function ⊑(@nospecialize(a), @nospecialize(b))
return a === b
end
end
(@nospecialize(a), @nospecialize(b)) = !(b, a)

# Check if two lattice elements are partial order equivalent. This is basically
# `a ⊑ b && b ⊑ a` but with extra performance optimizations.
Expand Down
23 changes: 23 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3645,3 +3645,26 @@ end

# issue #42646
@test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw

# form PartialStruct for extra type information propagation
struct FieldTypeRefinement{S,T}
s::S
t::T
end
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Any,Int}(s, s)
o.s
end |> only == Int
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Int,Any}(s, s)
o.t
end |> only == Int
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Any,Any}(s, s)
o.s, o.t
end |> only == Tuple{Int,Int}
@test Base.return_types((Int,)) do a
s1 = Some{Any}(a)
s2 = Some{Any}(s1)
s2.value.value
end |> only == Int
23 changes: 7 additions & 16 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,31 +426,22 @@ let # `getfield_elim_pass!` should work with constant globals
end
end

let # `typeassert_elim_pass!`
let
# `typeassert` elimination after SROA
# NOTE we can remove this optimization once inference is able to reason about memory-effects
src = @eval Module() begin
struct Foo; x; end
mutable struct Foo; x; end

code_typed((Int,)) do a
x1 = Foo(a)
x2 = Foo(x1)
x3 = Foo(x2)

r1 = (x2.x::Foo).x
r2 = (x2.x::Foo).x::Int
r3 = (x2.x::Foo).x::Integer
r4 = ((x3.x::Foo).x::Foo).x

return r1, r2, r3, r4
return typeassert(x2.x, Foo).x
end |> only |> first
end
# eliminate `typeassert(f2.a, Foo)`
@test all(src.code) do @nospecialize(stmt)
# eliminate `typeassert(x2.x, Foo)`
@test all(src.code) do @nospecialize stmt
Meta.isexpr(stmt, :call) || return true
ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes)
return Core.Compiler.widenconst(ft) !== typeof(typeassert)
end
# succeeding simple DCE will eliminate `Foo(a)`
@test all(src.code) do @nospecialize(stmt)
return !Meta.isexpr(stmt, :new)
end
end

0 comments on commit ab23760

Please sign in to comment.