Skip to content

Commit

Permalink
Even more indexed typeinfo (#1916)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 28, 2024
1 parent b999e5a commit 55582f8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 35 deletions.
61 changes: 32 additions & 29 deletions src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,34 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl)
end
end

"""
    get_base_and_offset(larg::LLVM.Value) -> Tuple{LLVM.Value, Int, Bool}

Walk up the operand chain of `larg`, looking through bitcasts, address-space
casts and `getelementptr` instructions whose indices are all constant integers.

Returns `(base, offset, error)` where

- `base` is the value at which the walk stopped,
- `offset` is the sum of the values returned by
  `API.EnzymeComputeByteOffsetOfGEP` for every GEP that was stepped through,
- `error` is `false` when `base` is an `LLVM.Argument` and `true` when the walk
  stopped at any other kind of value.
"""
function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int, Bool}
    offset = 0
    while true
        # Casts do not move the pointer; step to the value being cast.
        if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst)
            larg = operands(larg)[1]
            continue
        end

        # Only GEPs with all-constant indices are folded into the offset.
        if isa(larg, LLVM.GetElementPtrInst) &&
           all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end])
            b = LLVM.IRBuilder()
            try
                position!(b, larg)
                offty = LLVM.IntType(8 * sizeof(Int))
                offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty)
                @assert isa(offset2, LLVM.ConstantInt)
                offset += convert(Int, offset2)
            finally
                # The builder wraps a native LLVM object; release it even if
                # the assertion above throws, otherwise one builder is leaked
                # per GEP visited.
                LLVM.dispose(b)
            end
            larg = operands(larg)[1]
            continue
        end

        # Anything else ends the walk; only an Argument counts as success.
        return larg, offset, !isa(larg, LLVM.Argument)
    end
end

function abs_typeof(
arg::LLVM.Value,
partial::Bool = false,
Expand Down Expand Up @@ -354,40 +382,15 @@ function abs_typeof(
end

if isa(arg, LLVM.LoadInst)
larg = operands(arg)[1]
offset = nothing
error = false
while true
if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst)
larg = operands(larg)[1]
continue
end
if offset === nothing &&
isa(larg, LLVM.GetElementPtrInst) &&
all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end])
b = LLVM.IRBuilder()
position!(b, larg)
offty = LLVM.IntType(8 * sizeof(Int))
offset = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty)
@assert isa(offset, LLVM.ConstantInt)
offset = convert(Int, offset)
larg = operands(larg)[1]
continue
end
if isa(larg, LLVM.Argument)
break
end
error = true
break
end
larg, offset, error = get_base_and_offset(operands(arg)[1])

if !error
legal, typ, byref = abs_typeof(larg)
if legal && (byref == GPUCompiler.MUT_REF || byref == GPUCompiler.BITS_REF) && Base.isconcretetype(typ)
@static if VERSION < v"1.11-"
if typ <: Array && Base.isconcretetype(typ)
T = eltype(typ)
if offset === nothing || offset == 0
if offset == 0
return (true, Ptr{T}, GPUCompiler.BITS_VALUE)
else
return (true, Int, GPUCompiler.BITS_VALUE)
Expand All @@ -400,14 +403,14 @@ function abs_typeof(
byref = GPUCompiler.BITS_VALUE
legal = true

while (offset !== nothing && offset != 0) && legal
while offset != 0 && legal
@assert Base.isconcretetype(typ)
seen = false
lasti = 1
for i = 1:fieldcount(typ)
fo = fieldoffset(typ, i)
if fieldoffset(typ, i) == offset
offset = nothing
offset = 0
typ = fieldtype(typ, i)
if !Base.allocatedinline(typ)
if byref != GPUCompiler.BITS_VALUE
Expand Down
8 changes: 5 additions & 3 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7988,7 +7988,8 @@ function GPUCompiler.codegen(
if intr == LLVM.Intrinsic("llvm.memcpy").id ||
intr == LLVM.Intrinsic("llvm.memmove").id ||
intr == LLVM.Intrinsic("llvm.memset").id
legal, jTy, byref = abs_typeof(operands(inst)[1])
base, offset, _ = get_base_and_offset(operands(inst)[1])
legal, jTy, byref = abs_typeof(base)
sz =
if intr == LLVM.Intrinsic("llvm.memcpy").id ||
intr == LLVM.Intrinsic("llvm.memmove").id
Expand All @@ -8007,8 +8008,9 @@ function GPUCompiler.codegen(
any(T2 isa Core.TypeofVararg for T2 in jTy.parameters)
)
)
if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz)
md = to_fullmd(jTy)
if offset < sizeof(jTy) && isa(sz, LLVM.ConstantInt) && sizeof(jTy) - offset >= convert(Int, sz)
lim = convert(Int, sz)
md = to_fullmd(jTy, offset, lim)
@assert byref == GPUCompiler.BITS_REF ||
byref == GPUCompiler.MUT_REF
metadata(inst)["enzyme_truetype"] = md
Expand Down
25 changes: 22 additions & 3 deletions src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,28 @@ function get_offsets(@nospecialize(T::Type))
return results
end

function to_fullmd(@nospecialize(T::Type))
function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int)
mds = LLVM.Metadata[]
for (sT, sO) in get_offsets(T)
offs = get_offsets(T)

minoff = -1
for (sT, sO) in offs
if sO >= offset
if sO == offset
minOff = sO
end
else
minoff = max(minoff, sO)
end
end

for (sT, sO) in offs
if sO != minoff && (sO < offset)
continue
end
if sO >= lim
continue
end
if sT == API.DT_Pointer
push!(mds, LLVM.MDString("Pointer"))
elseif sT == API.DT_Integer
Expand All @@ -155,7 +174,7 @@ function to_fullmd(@nospecialize(T::Type))
else
@assert false
end
push!(mds, LLVM.Metadata(LLVM.ConstantInt(sO)))
push!(mds, LLVM.Metadata(LLVM.ConstantInt(min(0, sO - offset))))
end
return LLVM.MDNode(mds)
end
Expand Down

0 comments on commit 55582f8

Please sign in to comment.