From 55582f8ba60f41411d162b3d3d73155d0545199c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 28 Sep 2024 14:00:41 -0500 Subject: [PATCH] Even more indexed typeinfo (#1916) --- src/absint.jl | 61 ++++++++++++++++++++++++++----------------------- src/compiler.jl | 8 ++++--- src/typetree.jl | 25 +++++++++++++++++--- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 041b3bd1cc..77ce2b6a7e 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -204,6 +204,34 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl) end end +function get_base_and_offset(larg::LLVM.Value)::Tuple{LLVM.Value, Int, Bool} + offset = 0 + error = false + while true + if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) + larg = operands(larg)[1] + continue + end + if isa(larg, LLVM.GetElementPtrInst) && + all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) + b = LLVM.IRBuilder() + position!(b, larg) + offty = LLVM.IntType(8 * sizeof(Int)) + offset2 = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) + @assert isa(offset2, LLVM.ConstantInt) + offset += convert(Int, offset2) + larg = operands(larg)[1] + continue + end + if isa(larg, LLVM.Argument) + break + end + error = true + break + end + return larg, offset, error +end + function abs_typeof( arg::LLVM.Value, partial::Bool = false, @@ -354,32 +382,7 @@ function abs_typeof( end if isa(arg, LLVM.LoadInst) - larg = operands(arg)[1] - offset = nothing - error = false - while true - if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) - larg = operands(larg)[1] - continue - end - if offset === nothing && - isa(larg, LLVM.GetElementPtrInst) && - all(x -> isa(x, LLVM.ConstantInt), operands(larg)[2:end]) - b = LLVM.IRBuilder() - position!(b, larg) - offty = LLVM.IntType(8 * sizeof(Int)) - offset = API.EnzymeComputeByteOffsetOfGEP(b, larg, offty) - @assert isa(offset, LLVM.ConstantInt) - offset = convert(Int, offset) - larg = operands(larg)[1] - continue - end - if isa(larg, LLVM.Argument) - break - end - error = true - break - end + larg, offset, error = get_base_and_offset(operands(arg)[1]) if !error legal, typ, byref = abs_typeof(larg) @@ -387,7 +390,7 @@ function abs_typeof( @static if VERSION < v"1.11-" if typ <: Array && Base.isconcretetype(typ) T = eltype(typ) - if offset === nothing || offset == 0 + if offset == 0 return (true, Ptr{T}, GPUCompiler.BITS_VALUE) else return (true, Int, GPUCompiler.BITS_VALUE) @@ -400,14 +403,14 @@ function abs_typeof( byref = GPUCompiler.BITS_VALUE legal = true - while (offset !== nothing && offset != 0) && legal + while offset != 0 && legal @assert Base.isconcretetype(typ) seen = false lasti = 1 for i = 1:fieldcount(typ) fo = fieldoffset(typ, i) if fieldoffset(typ, i) == offset - offset = nothing + offset = 0 typ = fieldtype(typ, i) if !Base.allocatedinline(typ) if byref != GPUCompiler.BITS_VALUE diff --git a/src/compiler.jl b/src/compiler.jl index 08dc5f05c9..e890bd998d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -7988,7 +7988,8 @@ function GPUCompiler.codegen( if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id || intr == LLVM.Intrinsic("llvm.memset").id - legal, jTy, byref = abs_typeof(operands(inst)[1]) + base, offset, _ = get_base_and_offset(operands(inst)[1]) + legal, jTy, byref = abs_typeof(base) sz = if intr == LLVM.Intrinsic("llvm.memcpy").id || intr == LLVM.Intrinsic("llvm.memmove").id @@ -8007,8 +8008,9 @@ function GPUCompiler.codegen( any(T2 isa Core.TypeofVararg for T2 in jTy.parameters) ) ) - if isa(sz, LLVM.ConstantInt) && sizeof(jTy) == convert(Int, sz) - md = to_fullmd(jTy) + if offset < sizeof(jTy) && isa(sz, LLVM.ConstantInt) && sizeof(jTy) - offset >= convert(Int, sz) + lim = convert(Int, sz) + md = to_fullmd(jTy, offset, lim) @assert byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF metadata(inst)["enzyme_truetype"] = md diff --git a/src/typetree.jl b/src/typetree.jl index 8ddce070b2..c96d41fb2b 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -137,9 +137,28 @@ function get_offsets(@nospecialize(T::Type)) return results end -function to_fullmd(@nospecialize(T::Type)) +function to_fullmd(@nospecialize(T::Type), offset::Int, lim::Int) mds = LLVM.Metadata[] - for (sT, sO) in get_offsets(T) + offs = get_offsets(T) + + minoff = -1 + for (sT, sO) in offs + if sO >= offset + if sO == offset + minOff = sO + end + else + minoff = max(minoff, sO) + end + end + + for (sT, sO) in offs + if sO != minoff && (sO < offset) + continue + end + if sO >= lim + continue + end if sT == API.DT_Pointer push!(mds, LLVM.MDString("Pointer")) elseif sT == API.DT_Integer @@ -155,7 +174,7 @@ function to_fullmd(@nospecialize(T::Type)) else @assert false end - push!(mds, LLVM.Metadata(LLVM.ConstantInt(sO))) + push!(mds, LLVM.Metadata(LLVM.ConstantInt(min(0, sO - offset)))) end return LLVM.MDNode(mds) end