Skip to content

Commit

Permalink
Fix sret undef (#1990)
Browse files Browse the repository at this point in the history
* Fix sret undef

* add test

* fix

* 1.11: the adventure continues, destroy (#1986)

* 1.11: the adventure continues, destroy

* fix

* fixup

* fix

* cleanup

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 22, 2024
1 parent 1c69a70 commit 80c9887
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1559,7 +1559,7 @@ end
Base.@_inline_meta
ntuple(Val(N)) do idx
Base.@_inline_meta
return (i == idx) ? 1.0 : 0.0
return (i == idx) ? T(1) : T(0)
end
end
end
Expand All @@ -1571,7 +1571,7 @@ end
Base.@_inline_meta
ntuple(Val(N)) do idx
Base.@_inline_meta
return (i + start - 1 == idx) ? 1.0 : 0.0
return (i + start - 1 == idx) ? T(1) : T(0)
end
end
end
Expand Down
63 changes: 53 additions & 10 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1684,17 +1684,23 @@ function propagate_returned!(mod::LLVM.Module)
illegalUse = true
break
end
if !isa(ops[i], LLVM.AllocaInst)
if !isa(ops[i], LLVM.AllocaInst) && !isa(ops[i], LLVM.UndefValue) && !isa(ops[i], LLVM.PoisonValue)
illegalUse = true
break
end
eltype = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i]))
eltype = if isa(ops[i], LLVM.AllocaInst)
LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i]))
else
LLVM.eltype(value_type(ops[i]))
end
seenfn = false
todo = LLVM.Instruction[]
for u2 in LLVM.uses(ops[i])
if isa(ops[i], LLVM.AllocaInst)
for u2 in LLVM.uses(ops[i])
un2 = LLVM.user(u2)
push!(todo, un2)
end
end
while length(todo) > 0
un2 = pop!(todo)
if isa(un2, LLVM.BitCastInst)
Expand All @@ -1705,6 +1711,14 @@ function propagate_returned!(mod::LLVM.Module)
end
continue
end
if isa(un2, LLVM.GetElementPtrInst)
push!(torem, un2)
for u3 in LLVM.uses(un2)
un3 = LLVM.user(u3)
push!(todo, un3)
end
continue
end
if !isa(un2, LLVM.CallInst)
illegalUse = true
break
Expand Down Expand Up @@ -1776,14 +1790,9 @@ function propagate_returned!(mod::LLVM.Module)
illegalUse = true
break
end
if isa(ops[i], LLVM.UndefValue)
if isa(ops[i], LLVM.UndefValue) || isa(ops[i], LLVM.PoisonValue)
continue
end
@static if LLVM.version() >= v"12"
if isa(ops[i], LLVM.PoisonValue)
continue
end
end
if ops[i] == arg
continue
end
Expand Down Expand Up @@ -1911,6 +1920,7 @@ function propagate_returned!(mod::LLVM.Module)
un = LLVM.user(u)
push!(next, LLVM.name(LLVM.parent(LLVM.parent(un))))
end
delete_writes_into_removed_args(fn, toremove)
nfn = LLVM.Function(
API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove),
)
Expand Down Expand Up @@ -1953,6 +1963,39 @@ function propagate_returned!(mod::LLVM.Module)
end
end
end

function delete_writes_into_removed_args(fn::LLVM.Function, toremove)
args = collect(parameters(fn))
for tr in toremove
tr = tr + 1
todorep = Tuple{LLVM.Instruction, LLVM.Value}[]
for opv in LLVM.uses(args[tr])
u = LLVM.user(opv)
push!(todorep, (u, args[tr]))
end
toerase = LLVM.Instruction[]
while length(todorep) != 0
cur, cval = pop!(todorep)
if isa(cur, LLVM.StoreInst)
if operands(cur)[2] == cval
LLVM.API.LLVMInstructionEraseFromParent(nphi)
continue
end
end
if isa(cur, LLVM.GetElementPtrInst) ||
isa(cur, LLVM.BitCastInst) ||
isa(cur, LLVM.AddrSpaceCastInst)
for opv in LLVM.uses(cur)
u = LLVM.user(opv)
push!(todorep, (u, cur))
end
continue
end
throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))"))
end
end
end

function detect_writeonly!(mod::LLVM.Module)
for f in functions(mod)
if isempty(LLVM.blocks(f))
Expand Down Expand Up @@ -2376,7 +2419,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm)
kind(attr) == kind(StringAttribute("enzyme_sret")) ||
kind(attr) == kind(StringAttribute("enzyme_sret_v"))
) for attr in attrs
)
) && any_jltypes(sret_ty(fn, idx))
for u in LLVM.uses(fn)
u = LLVM.user(u)
if isa(u, LLVM.ConstantExpr)
Expand Down
40 changes: 40 additions & 0 deletions src/rules/typerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,43 @@ function inoutcopyslice_rule(
end
return UInt8(false)
end

function inoutgcloaded_rule(
direction::Cint,
ret::API.CTypeTreeRef,
args::Ptr{API.CTypeTreeRef},
known_values::Ptr{API.IntList},
numArgs::Csize_t,
val::LLVM.API.LLVMValueRef,
)::UInt8
if numArgs != 1
return UInt8(false)
end
inst = LLVM.Instruction(val)

legal, typ = abs_typeof(inst)

if legal
if (direction & API.DOWN) != 0
ctx = LLVM.context(inst)
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst)))))
if GPUCompiler.deserves_retbox(typ)
typ = Ptr{typ}
end
rest = typetree(typ, ctx, dl)
changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest)
@assert legal
end
return UInt8(false)
end

if (direction & API.UP) != 0
changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret)
@assert legal
end
if (direction & API.DOWN) != 0
changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2))
@assert legal
end
return UInt8(false)
end
9 changes: 8 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ end

export my_methodinstance


@static if VERSION < v"1.11-"

@inline function typed_fieldtype(@nospecialize(T::Type), i::Int)
Expand All @@ -352,3 +351,11 @@ end
end

export typed_fieldtype

# returns the inner type of an sret/enzyme_sret/enzyme_sret_v
function sret_ty(fn::LLVM.Function, idx::Int)
return eltype(LLVM.value_type(LLVM.parameters(fn)[idx]))
end

export sret_ty

30 changes: 30 additions & 0 deletions test/abi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,36 @@ end
@inferred hvp_and_gradient!(zeros(2), zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7])
end

function ulogistic(x)
return x > 36 ? one(x) : 1 / (one(x) + 1/x)
end

@noinline function u_transform_tuple(x)
yfirst = ulogistic(@inbounds x[1])
yfirst, 2
end


@noinline function mytransform(ts, x)
yfirst = ulogistic(@inbounds x[1])
yrest, _ = u_transform_tuple(x)
(yfirst, yrest)
end

function undefsret(trf, x)
p = mytransform(trf, x)
return 1/(p[2])
end

@testset "Undef sret" begin
trf = 0.1

x = randn(3)
dx = zero(x)
undefsret(trf, x)
autodiff(Reverse, undefsret, Active, Const(trf), Duplicated(x, dx))
end

struct ByRefStruct
x::Vector{Float64}
v::Vector{Float64}
Expand Down

0 comments on commit 80c9887

Please sign in to comment.