diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index 0a3c030f11..c3769e35ac 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -1559,7 +1559,7 @@ end
         Base.@_inline_meta
         ntuple(Val(N)) do idx
             Base.@_inline_meta
-            return (i == idx) ? 1.0 : 0.0
+            return (i == idx) ? T(1) : T(0)
         end
     end
 end
@@ -1571,7 +1571,7 @@ end
         Base.@_inline_meta
         ntuple(Val(N)) do idx
             Base.@_inline_meta
-            return (i + start - 1 == idx) ? 1.0 : 0.0
+            return (i + start - 1 == idx) ? T(1) : T(0)
         end
     end
 end
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index 2a1ef49553..dc26d140bb 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -1684,17 +1684,23 @@ function propagate_returned!(mod::LLVM.Module)
                             illegalUse = true
                             break
                         end
-                        if !isa(ops[i], LLVM.AllocaInst)
+                        if !isa(ops[i], LLVM.AllocaInst) && !isa(ops[i], LLVM.UndefValue) && !isa(ops[i], LLVM.PoisonValue)
                             illegalUse = true
                             break
                         end
-                        eltype = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i]))
+                        eltype = if isa(ops[i], LLVM.AllocaInst)
+                            LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(ops[i]))
+                        else
+                            LLVM.eltype(value_type(ops[i]))
+                        end
                         seenfn = false
                         todo = LLVM.Instruction[]
-                        for u2 in LLVM.uses(ops[i])
+                        if isa(ops[i], LLVM.AllocaInst)
+                        for u2 in LLVM.uses(ops[i])
                             un2 = LLVM.user(u2)
                             push!(todo, un2)
                         end
+                        end
                         while length(todo) > 0
                             un2 = pop!(todo)
                             if isa(un2, LLVM.BitCastInst)
@@ -1705,6 +1711,14 @@ function propagate_returned!(mod::LLVM.Module)
                                 end
                                 continue
                             end
+                            if isa(un2, LLVM.GetElementPtrInst)
+                                push!(torem, un2)
+                                for u3 in LLVM.uses(un2)
+                                    un3 = LLVM.user(u3)
+                                    push!(todo, un3)
+                                end
+                                continue
+                            end
                             if !isa(un2, LLVM.CallInst)
                                 illegalUse = true
                                 break
@@ -1776,14 +1790,9 @@ function propagate_returned!(mod::LLVM.Module)
                             illegalUse = true
                             break
                         end
-                        if isa(ops[i], LLVM.UndefValue)
+                        if isa(ops[i], LLVM.UndefValue) || isa(ops[i], LLVM.PoisonValue)
                             continue
                         end
-                        @static if LLVM.version() >= v"12"
-                            if isa(ops[i], LLVM.PoisonValue)
-                                continue
-                            end
-                        end
                         if ops[i] == arg
                             continue
                         end
@@ -1911,6 +1920,7 @@ function propagate_returned!(mod::LLVM.Module)
                     un = LLVM.user(u)
                     push!(next, LLVM.name(LLVM.parent(LLVM.parent(un))))
                 end
+                delete_writes_into_removed_args(fn, toremove)
                 nfn = LLVM.Function(
                     API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove),
                 )
@@ -1953,6 +1963,46 @@ function propagate_returned!(mod::LLVM.Module)
         end
     end
 end
+
+# Erase stores into arguments that are about to be removed from `fn`.
+#
+# `toremove` holds argument indices that are treated as 0-based (each one is
+# incremented before indexing `parameters(fn)`). For every such argument, users are
+# walked transitively through GEP / bitcast / addrspacecast instructions; any store
+# whose second operand is the argument (or a pointer derived from it) is erased.
+# Any other kind of user raises an `AssertionError`.
+function delete_writes_into_removed_args(fn::LLVM.Function, toremove)
+    args = collect(parameters(fn))
+    for tr in toremove
+        tr = tr + 1
+        todorep = Tuple{LLVM.Instruction, LLVM.Value}[]
+        for opv in LLVM.uses(args[tr])
+            u = LLVM.user(opv)
+            push!(todorep, (u, args[tr]))
+        end
+        while length(todorep) != 0
+            cur, cval = pop!(todorep)
+            if isa(cur, LLVM.StoreInst)
+                # Only erase when the tracked value is the store's destination operand.
+                if operands(cur)[2] == cval
+                    LLVM.API.LLVMInstructionEraseFromParent(cur)
+                    continue
+                end
+            end
+            if isa(cur, LLVM.GetElementPtrInst) ||
+               isa(cur, LLVM.BitCastInst) ||
+               isa(cur, LLVM.AddrSpaceCastInst)
+                for opv in LLVM.uses(cur)
+                    u = LLVM.user(opv)
+                    push!(todorep, (u, cur))
+                end
+                continue
+            end
+            throw(AssertionError("Deleting argument with an unknown dependency, $(string(cur)) uses $(string(cval))"))
+        end
+    end
+end
+
 function detect_writeonly!(mod::LLVM.Module)
     for f in functions(mod)
         if isempty(LLVM.blocks(f))
@@ -2376,7 +2426,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm)
                     kind(attr) == kind(StringAttribute("enzyme_sret")) ||
                     kind(attr) == kind(StringAttribute("enzyme_sret_v"))
                 ) for attr in attrs
-            )
+            ) && any_jltypes(sret_ty(fn, idx))
                 for u in LLVM.uses(fn)
                     u = LLVM.user(u)
                     if isa(u, LLVM.ConstantExpr)
diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl
index 2cba33c14e..de11d3c1cd 100644
--- a/src/rules/typerules.jl
+++ b/src/rules/typerules.jl
@@ -92,3 +92,52 @@ function inoutcopyslice_rule(
     end
     return UInt8(false)
 end
+
+# Type rule relating the return value of the call `val` to its second argument.
+#
+# If `abs_typeof(inst)` reports `legal`, the type tree of the returned type
+# (wrapped in `Ptr` when `GPUCompiler.deserves_retbox` holds) is merged into `ret`
+# in the DOWN direction. Otherwise type trees are merged from `ret` into argument 2
+# (UP) and from argument 2 into `ret` (DOWN). Always returns `UInt8(false)`.
+function inoutgcloaded_rule(
+    direction::Cint,
+    ret::API.CTypeTreeRef,
+    args::Ptr{API.CTypeTreeRef},
+    known_values::Ptr{API.IntList},
+    numArgs::Csize_t,
+    val::LLVM.API.LLVMValueRef,
+)::UInt8
+    # The rule reads `unsafe_load(args, 2)` below, so two arguments must be present.
+    # NOTE(review): the guard previously required exactly one argument, which made
+    # the later loads read past the end of `args` -- confirm the callee's arity is 2.
+    if numArgs != 2
+        return UInt8(false)
+    end
+    inst = LLVM.Instruction(val)
+
+    legal, typ = abs_typeof(inst)
+
+    if legal
+        if (direction & API.DOWN) != 0
+            ctx = LLVM.context(inst)
+            dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst)))))
+            if GPUCompiler.deserves_retbox(typ)
+                typ = Ptr{typ}
+            end
+            rest = typetree(typ, ctx, dl)
+            changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest)
+            @assert legal
+        end
+        return UInt8(false)
+    end
+
+    if (direction & API.UP) != 0
+        changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret)
+        @assert legal
+    end
+    if (direction & API.DOWN) != 0
+        changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2))
+        @assert legal
+    end
+    return UInt8(false)
+end
\ No newline at end of file
diff --git a/src/utils.jl b/src/utils.jl
index 0b441b1b25..55dc69769e 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -327,7 +327,6 @@ end
 
 export my_methodinstance
 
-
 @static if VERSION < v"1.11-"
 
 @inline function typed_fieldtype(@nospecialize(T::Type), i::Int)
@@ -352,3 +351,11 @@ end
 end
 
 export typed_fieldtype
+
+# returns the inner type of an sret/enzyme_sret/enzyme_sret_v
+function sret_ty(fn::LLVM.Function, idx::Int)
+    return eltype(LLVM.value_type(LLVM.parameters(fn)[idx]))
+end
+
+export sret_ty
+
diff --git a/test/abi.jl b/test/abi.jl
index acc8f26090..f27affd3a4 100644
--- a/test/abi.jl
+++ b/test/abi.jl
@@ -544,6 +544,36 @@ end
     @inferred hvp_and_gradient!(zeros(2), zeros(2), mulsin, [2.0, 3.0], [5.0, 2.7])
 end
 
+function ulogistic(x)
+    return x > 36 ? one(x) : 1 / (one(x) + 1/x)
+end
+
+@noinline function u_transform_tuple(x)
+    yfirst = ulogistic(@inbounds x[1])
+    yfirst, 2
+end
+
+
+@noinline function mytransform(ts, x)
+    yfirst = ulogistic(@inbounds x[1])
+    yrest, _ = u_transform_tuple(x)
+    (yfirst, yrest)
+end
+
+function undefsret(trf, x)
+    p = mytransform(trf, x)
+    return 1/(p[2])
+end
+
+@testset "Undef sret" begin
+    trf = 0.1
+
+    x = randn(3)
+    dx = zero(x)
+    undefsret(trf, x)
+    autodiff(Reverse, undefsret, Active, Const(trf), Duplicated(x, dx))
+end
+
 struct ByRefStruct
     x::Vector{Float64}
     v::Vector{Float64}