Skip to content

Commit

Permalink
inference: apply tmerge limit elementwise to the Union
Browse files Browse the repository at this point in the history
This allows forming larger unions, as long as each element in the Union
is both relatively distinct and relatively simple. For example:

    tmerge(Base.BitSigned, Nothing) == Union{Nothing, Int128, Int16, Int32, Int64, Int8}
    tmerge(Tuple{Base.BitSigned, Int}, Nothing) == Union{Nothing, Tuple{Any, Int64}}
    tmerge(AbstractVector{Int}, Vector) == AbstractVector

Disables a test from dc8d885.

Co-authored-by: Oscar Smith <oscardssmith@gmail.com>
  • Loading branch information
2 people authored and vtjnash committed Sep 12, 2023
1 parent 832e46d commit f4ac759
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 38 deletions.
69 changes: 51 additions & 18 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,33 @@ end
return tmerge_types_slow(typea, typeb)
end

@nospecializeinfer @noinline function tname_intersect(aname::Core.TypeName, bname::Core.TypeName)
aname === bname && return aname
if !isabstracttype(aname.wrapper) && !isabstracttype(bname.wrapper)
return nothing # fast path
end
Any.name === aname && return aname
a = unwrap_unionall(aname.wrapper)
heighta = 0
while a !== Any
heighta += 1
a = a.super
end
b = unwrap_unionall(bname.wrapper)
heightb = 0
while b !== Any
b.name === aname && return aname
heightb += 1
b = b.super
end
a = unwrap_unionall(aname.wrapper)
while heighta > heightb
a = a.super
heighta -= 1
end
return a.name === bname ? bname : nothing
end

@nospecializeinfer @noinline function tmerge_types_slow(@nospecialize(typea::Type), @nospecialize(typeb::Type))
# collect the list of types from past tmerge calls returning Union
# and then reduce over that list
Expand All @@ -716,9 +743,12 @@ end
# in which case, simplify this tmerge by replacing it with
# the widest possible version of itself (the wrapper)
for i in 1:length(types)
typenames[i] === Any.name && continue
ti = types[i]
for j in (i + 1):length(types)
if typenames[i] === typenames[j]
typenames[j] === Any.name && continue
ijname = tname_intersect(typenames[i], typenames[j])
if !(ijname === nothing)
tj = types[j]
if ti <: tj
types[i] = Union{}
Expand All @@ -728,27 +758,33 @@ end
types[j] = Union{}
typenames[j] = Any.name
else
if typenames[i] === Tuple.name
if ijname === Tuple.name
# try to widen Tuple slower: make a single non-concrete Tuple containing both
# converge the Tuple element-wise if they are the same length
# see 4ee2b41552a6bc95465c12ca66146d69b354317b, be59686f7613a2ccfd63491c7b354d0b16a95c05,
widen = tuplemerge(unwrap_unionall(ti)::DataType, unwrap_unionall(tj)::DataType)
widen = rewrap_unionall(rewrap_unionall(widen, ti), tj)
else
wr = typenames[i].wrapper
wr = ijname.wrapper
uw = unwrap_unionall(wr)::DataType
ui = unwrap_unionall(ti)::DataType
while ui.name !== ijname
ui = ui.super
end
uj = unwrap_unionall(tj)::DataType
merged = wr
while uj.name !== ijname
uj = uj.super
end
merged = Vector{Any}(undef, length(uw.parameters))
for k = 1:length(uw.parameters)
ui_k = ui.parameters[k]
if ui_k === uj.parameters[k] && !has_free_typevars(ui_k)
merged = merged{ui_k}
merged[k] = ui_k
else
merged = merged{uw.parameters[k]}
merged[k] = uw.parameters[k]
end
end
widen = rewrap_unionall(merged, wr)
widen = rewrap_unionall(wr{merged...}, wr)
end
types[i] = Union{}
typenames[i] = Any.name
Expand All @@ -758,31 +794,28 @@ end
end
end
end
u = Union{types...}
# don't let type unions get too big, if the above didn't reduce it enough
if issimpleenoughtype(u)
return u
end
# don't let the slow widening of Tuple cause the whole type to grow too fast
# don't let elements of the union get too big, if the above didn't reduce something
# Specifically widen Tuple{..., Union{lots of stuff}...} to Tuple{..., Any, ...}
for i in 1:length(types)
# this element is too complicated, so
# just return the widest possible type now
issimpleenoughtype(types[i]) && continue
if typenames[i] === Tuple.name
ti = types[i]
tip = (unwrap_unionall(types[i])::DataType).parameters
lt = length(tip)
p = Vector{Any}(undef, lt)
for j = 1:lt
ui = tip[j]
p[j] = (unioncomplexity(ui)==0) ? ui : isvarargtype(ui) ? Vararg : Any
p[j] = issimpleenoughtype(unwrapva(ui)) ? ui : isvarargtype(ui) ? Vararg : Any
end
types[i] = rewrap_unionall(Tuple{p...}, ti)
else
issimpleenoughtype(types[i]) || return Any
end
end
u = Union{types...}
if issimpleenoughtype(u)
return u
end
return Any
return u
end

# the inverse of switchtupleunion, with limits on max element union size
Expand Down
47 changes: 27 additions & 20 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,40 +164,40 @@ Base.ndims(g::e43296) = ndims(typeof(g))


# PR 22120
function tmerge_test(a, b, r, commutative=true)
function tuplemerge_test(a, b, r, commutative=true)
@test r == Core.Compiler.tuplemerge(a, b)
if commutative
@test r == Core.Compiler.tuplemerge(b, a)
else
@test_broken r == Core.Compiler.tuplemerge(b, a)
end
end
tmerge_test(Tuple{Int}, Tuple{String}, Tuple{Union{Int, String}})
tmerge_test(Tuple{Int}, Tuple{String, String}, Tuple)
tmerge_test(Tuple{Vararg{Int}}, Tuple{String}, Tuple)
tmerge_test(Tuple{Int}, Tuple{Int, Int},
tuplemerge_test(Tuple{Int}, Tuple{String}, Tuple{Union{Int, String}})
tuplemerge_test(Tuple{Int}, Tuple{String, String}, Tuple)
tuplemerge_test(Tuple{Vararg{Int}}, Tuple{String}, Tuple)
tuplemerge_test(Tuple{Int}, Tuple{Int, Int},
Tuple{Vararg{Int}})
tmerge_test(Tuple{Integer}, Tuple{Int, Int},
tuplemerge_test(Tuple{Integer}, Tuple{Int, Int},
Tuple{Vararg{Integer}})
tmerge_test(Tuple{}, Tuple{Int, Int},
tuplemerge_test(Tuple{}, Tuple{Int, Int},
Tuple{Vararg{Int}})
tmerge_test(Tuple{}, Tuple{Complex},
tuplemerge_test(Tuple{}, Tuple{Complex},
Tuple{Vararg{Complex}})
tmerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, ComplexF64},
tuplemerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, ComplexF64},
Tuple{Vararg{Complex}})
tmerge_test(Tuple{Vararg{ComplexF32}}, Tuple{Vararg{ComplexF64}},
tuplemerge_test(Tuple{Vararg{ComplexF32}}, Tuple{Vararg{ComplexF64}},
Tuple{Vararg{Complex}})
tmerge_test(Tuple{}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{ComplexF32, ComplexF32, ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{ComplexF32, ComplexF32, ComplexF32}, Tuple{ComplexF32, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{}, Tuple{Union{ComplexF64, ComplexF32}, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{}, Tuple{Union{ComplexF64, ComplexF32}, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Union{ComplexF32, ComplexF64}}})
tmerge_test(Tuple{ComplexF64, ComplexF64, ComplexF32}, Tuple{Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{ComplexF64, ComplexF64, ComplexF32}, Tuple{Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Complex}}, false)
tmerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}},
tuplemerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}},
Tuple{Vararg{Complex}})
@test Core.Compiler.tmerge(Tuple{}, Union{Nothing, Tuple{ComplexF32, ComplexF32}}) ==
Union{Nothing, Tuple{}, Tuple{ComplexF32, ComplexF32}}
Expand All @@ -215,8 +215,15 @@ tmerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}},
@test Core.Compiler.tmerge(Core.Compiler.fallback_ipo_lattice, Core.Compiler.InterConditional(1, Int, Union{}), Core.Compiler.InterConditional(2, String, Union{})) === Core.Compiler.Const(true)
# test issue behind https://github.com/JuliaLang/julia/issues/50458
@test Core.Compiler.tmerge(Nothing, Tuple{Base.BitInteger, Int}) == Union{Nothing, Tuple{Any, Int}}
@test Core.Compiler.tmerge(Nothing, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}) == Union{Nothing, Tuple{Any, Int}}
@test Core.Compiler.tmerge(Nothing, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}) == Union{Nothing, Tuple{Union{Char, String, SubString{String}, Symbol}, Int}}
@test Core.Compiler.tmerge(Nothing, Tuple{Integer, Int}) == Union{Nothing, Tuple{Integer, Int}}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Vector) == Union{Nothing, AbstractVector}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Matrix) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Matrix{Int}) == Union{Nothing, AbstractArray{Int}}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector{Int}}, Array) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractArray{Int}}, Vector) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractVector}, Matrix{Int}) == Union{Nothing, AbstractArray}
@test Core.Compiler.tmerge(Union{Nothing, AbstractFloat}, Integer) == Union{Nothing, AbstractFloat, Integer}

# test that recursively more complicated types don't widen all the way to Any when there is a useful valid type upper bound
# Specificially test with base types of a trivial type, a simple union, a complicated union, and a tuple.
Expand Down Expand Up @@ -2886,7 +2893,7 @@ end
# issue #27316 - inference shouldn't hang on these
f27316(::Vector) = nothing
f27316(::Any) = f27316(Any[][1]), f27316(Any[][1])
let expected = NTuple{2, Union{Nothing, NTuple{2, Union{Nothing, Tuple{Any, Any}}}}}
let expected = NTuple{2, Union{Nothing, Tuple{Any, Any}}}
@test Tuple{Nothing, Nothing} <: only(Base.return_types(f27316, Tuple{Int})) == expected # we may be able to improve this bound in the future
end
function g27316()
Expand Down Expand Up @@ -3501,8 +3508,8 @@ function pickvarnames(x::Vector{Any})
end
@test pickvarnames(:a) === :a
@test pickvarnames(Any[:a, :b]) === (:a, :b)
@test only(Base.return_types(pickvarnames, (Vector{Any},))) == Tuple{Vararg{Union{Symbol, Tuple}}}
@test only(Base.code_typed(pickvarnames, (Vector{Any},), optimize=false))[2] == Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple}}}}}}
@test only(Base.return_types(pickvarnames, (Vector{Any},))) == Tuple
@test only(Base.code_typed(pickvarnames, (Vector{Any},), optimize=false))[2] == Tuple{Vararg{Union{Symbol, Tuple}}}

@test map(>:, [Int], [Int]) == [true]

Expand Down

0 comments on commit f4ac759

Please sign in to comment.