Skip to content

Commit

Permalink
Union member type info (#1927)
Browse files Browse the repository at this point in the history
* Union member type info

* fix

* fix
  • Loading branch information
wsmoses authored Sep 30, 2024
1 parent 288a419 commit dad67bf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
7 changes: 6 additions & 1 deletion src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,17 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable)
for f = 1:fieldcount(T)
offset = fieldoffset(T, f)
subT = fieldtype(T, f)
subtree = copy(typetree(subT, ctx, dl, seen))

if subT isa UnionAll || subT isa Union || subT == Union{}
if !allocatedinline(subT)
subtree = TypeTree(API.DT_Pointer, offset, ctx)
merge!(tt, subtree)
end
# FIXME: Handle union
continue
end

subtree = copy(typetree(subT, ctx, dl, seen))

# Allocated inline so adjust first path
if allocatedinline(subT)
Expand Down
35 changes: 25 additions & 10 deletions test/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ struct Sibling2{T}
b::T
end

struct UnionMember
a::Float32
b::Union{Function, Number}
c::Bool
end

@testset "TypeTree" begin
@test tt(Float16) == "{[-1]:Float@half}"
@test tt(Float32) == "{[-1]:Float@float}"
Expand All @@ -55,28 +61,31 @@ end
@test at2.z == 0.0
@test at2.type == 4


if Sys.WORD_SIZE == 64
@test tt(LList2{Float64}) == "{[8]:Float@double}"
@test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}"
@test tt(UnionMember) == "{[0]:Float@float, [8]:Pointer, [16]:Integer}"
@test tt(LList2{Float64}) == "{[0]:Pointer, [8]:Float@double}"
@test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Float@double}"
@test tt(Sibling2{LList2{Float64}}) ==
"{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}"
"{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@double}"
@test tt(Sibling{Tuple{Int,Float64}}) ==
"{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}"
@test tt(Sibling{LList2{Tuple{Int,Float64}}}) ==
"{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}"
"{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}"
@test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) ==
"{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}"
"{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,0]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,0]:Pointer, [48,8]:Float@float, [48,16]:Float@double}"
else
@test tt(LList2{Float64}) == "{[4]:Float@double}"
@test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}"
@test tt(UnionMember) == "{[0]:Float@float, [4]:Pointer, [8]:Integer}"
@test tt(LList2{Float64}) == "{[0]:Pointer, [4]:Float@double}"
@test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Float@double}"
@test tt(Sibling2{LList2{Float64}}) ==
"{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}"
"{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@double}"
@test tt(Sibling{Tuple{Int,Float64}}) ==
"{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}"
@test tt(Sibling{LList2{Tuple{Int,Float64}}}) ==
"{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}"
"{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}"
@test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) ==
"{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}"
"{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,0]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,0]:Pointer, [24,4]:Float@float, [24,8]:Float@double}"
end
end

Expand All @@ -91,4 +100,10 @@ end
@test Enzyme.get_offsets(Ptr{Float32}) == ((Enzyme.API.DT_Pointer,0),)
@test Enzyme.get_offsets(Vector{Float32}) == ((Enzyme.API.DT_Pointer,0),)
@test Enzyme.get_offsets(Tuple{Float64, Int}) == [(Enzyme.API.DT_Double,0),(Enzyme.API.DT_Integer, 8)]

if Sys.WORD_SIZE == 64
@test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float,0),(Enzyme.API.DT_Pointer, 8), (Enzyme.API.DT_Integer, 16)]
else
@test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float, 0), (Enzyme.API.DT_Pointer, 4), (Enzyme.API.DT_Integer, 8)]
end
end

0 comments on commit dad67bf

Please sign in to comment.