From dad67bfc3913f4eb66126d7c186a59b0c1f18586 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 29 Sep 2024 22:39:20 -0500 Subject: [PATCH] Union member type info (#1927) * Union member type info * fix * fix --- src/typetree.jl | 7 ++++++- test/typetree.jl | 35 +++++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index c886c683ce..61d700acb8 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -353,12 +353,17 @@ function typetree_inner(@nospecialize(T::Type), ctx, dl, seen::TypeTreeTable) for f = 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) - subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} + if !allocatedinline(subT) + subtree = TypeTree(API.DT_Pointer, offset, ctx) + merge!(tt, subtree) + end # FIXME: Handle union continue end + + subtree = copy(typetree(subT, ctx, dl, seen)) # Allocated inline so adjust first path if allocatedinline(subT) diff --git a/test/typetree.jl b/test/typetree.jl index 1a869d6687..3b47161f62 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -37,6 +37,12 @@ struct Sibling2{T} b::T end +struct UnionMember + a::Float32 + b::Union{Function, Number} + c::Bool +end + @testset "TypeTree" begin @test tt(Float16) == "{[-1]:Float@half}" @test tt(Float32) == "{[-1]:Float@float}" @@ -55,28 +61,31 @@ end @test at2.z == 0.0 @test at2.type == 4 + if Sys.WORD_SIZE == 64 - @test tt(LList2{Float64}) == "{[8]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" + @test tt(UnionMember) == "{[0]:Float@float, [8]:Pointer, [16]:Integer}" + @test tt(LList2{Float64}) == "{[0]:Pointer, [8]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + "{[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,0]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,0]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,0]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" else - @test tt(LList2{Float64}) == "{[4]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" + @test tt(UnionMember) == "{[0]:Float@float, [4]:Pointer, [8]:Integer}" + @test tt(LList2{Float64}) == "{[0]:Pointer, [4]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" + "{[-1]:Pointer, [-1,0]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" + "{[0]:Pointer, [0,0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,0]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,0]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,0]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" end end @@ -91,4 +100,10 @@ end @test Enzyme.get_offsets(Ptr{Float32}) == ((Enzyme.API.DT_Pointer,0),) @test Enzyme.get_offsets(Vector{Float32}) == ((Enzyme.API.DT_Pointer,0),) @test Enzyme.get_offsets(Tuple{Float64, Int}) == [(Enzyme.API.DT_Double,0),(Enzyme.API.DT_Integer, 8)] + + if Sys.WORD_SIZE == 64 + @test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float,0),(Enzyme.API.DT_Pointer, 8), (Enzyme.API.DT_Integer, 16)] + else + @test Enzyme.get_offsets(UnionMember) == [(Enzyme.API.DT_Float, 0), (Enzyme.API.DT_Pointer, 4), (Enzyme.API.DT_Integer, 8)] + end end