Skip to content

Commit

Permalink
inference: refine PartialStruct lattice tmerge
Browse files Browse the repository at this point in the history
Be more aggressive about merging fields to greatly accelerate
convergence, but also compute anyrefine more correctly as we do now
elsewhere (since #42831, a121721)

Move the tmeet algorithm, without changes, since it is a precise lattice
operation, not a heuristic limit like tmerge.

Close #43784
  • Loading branch information
vtjnash committed Mar 11, 2022
1 parent 482d4f6 commit b8182c3
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 53 deletions.
42 changes: 42 additions & 0 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,45 @@ function stupdate1!(state::VarTable, change::StateUpdate)
end
return false
end

# compute typeintersect over the extended inference lattice,
# as precisely as we can,
# where v is in the extended lattice, and t is a Type.
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end
78 changes: 30 additions & 48 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
aty = widenconst(typea)
bty = widenconst(typeb)
if aty === bty
# must have egal here, since we do not create PartialStruct for non-concrete types
typea_nfields = nfields_tfunc(typea)
typeb_nfields = nfields_tfunc(typeb)
isa(typea_nfields, Const) || return aty
Expand All @@ -460,18 +461,40 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
type_nfields === typeb_nfields.val::Int || return aty
type_nfields == 0 && return aty
fields = Vector{Any}(undef, type_nfields)
anyconst = false
anyrefine = false
for i = 1:type_nfields
ai = getfield_tfunc(typea, Const(i))
bi = getfield_tfunc(typeb, Const(i))
ity = tmerge(ai, bi)
if ai === Union{} || bi === Union{}
ity = widenconst(ity)
ft = fieldtype(aty, i)
if is_lattice_equal(ai, bi) || is_lattice_equal(ai, ft)
# Since ai===bi, the given type has no restrictions on complexity.
# and can be used to refine ft
tyi = ai
elseif is_lattice_equal(bi, ft)
tyi = bi
else
# Otherwise choose between using the fieldtype or some other simple merged type.
# The wrapper type never has restrictions on complexity,
# so try to use that to refine the estimated type too.
tni = _typename(widenconst(ai))
if tni isa Const && tni === _typename(widenconst(bi))
# A tmeet call may cause tyi to become complex, but since the inputs were
# strictly limited to being egal, this has no restrictions on complexity.
# (Otherwise, we would need to use <: and take the narrower one without
# intersection. See the similar comment in abstract_call_method.)
tyi = typeintersect(ft, (tni.val::Core.TypeName).wrapper)
else
# Since aty===bty, the fieldtype has no restrictions on complexity.
tyi = ft
end
end
fields[i] = tyi
if !anyrefine
anyrefine = has_nontrivial_const_info(tyi) || # constant information
tyi ft # just a type-level information, but more precise than the declared type
end
fields[i] = ity
anyconst |= has_nontrivial_const_info(ity)
end
return anyconst ? PartialStruct(aty, fields) : aty
return anyrefine ? PartialStruct(aty, fields) : aty
end
end
if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb)
Expand Down Expand Up @@ -658,44 +681,3 @@ function tuplemerge(a::DataType, b::DataType)
end
return Tuple{p...}
end

# compute typeintersect over the extended inference lattice
# where v is in the extended lattice, and t is a Type
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end
28 changes: 23 additions & 5 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3999,15 +3999,33 @@ end
@test (a, c)
@test (b, c)

@test @eval Module() begin
const ginit = Base.ImmutableDict{Any,Any}()
Base.return_types() do
g = ginit
init = Base.ImmutableDict{Number,Number}()
a = Const(init)
b = Core.PartialStruct(typeof(init), Any[Const(init), Any, ComplexF64])
c = Core.Compiler.tmerge(a, b)
@test (a, c) && (b, c)
@test c === typeof(init)

a = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64])
c = Core.Compiler.tmerge(a, b)
@test (a, c) && (b, c)
@test c.fields[2] === Any # or Number
@test c.fields[3] === ComplexF64

b = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}])
c = Core.Compiler.tmerge(a, b)
@test (a, c)
@test (b, c)
@test c.fields[2] === Complex
@test c.fields[3] === Complex

global const ginit43784 = Base.ImmutableDict{Any,Any}()
@test Base.return_types() do
g = ginit43784
while true
g = Base.ImmutableDict(g, 1=>2)
end
end |> only === Union{}
end
end

# Test that purity modeling doesn't accidentally introduce new world age issues
Expand Down

0 comments on commit b8182c3

Please sign in to comment.