Skip to content

Commit

Permalink
Fix union abi (#1150)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Nov 18, 2023
1 parent f628cf4 commit 9579ec3
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3347,6 +3347,19 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) )
end

@inline function fixup_abi(index, value)
valty = sret_types[index]
# Union becoming part of a tuple needs to be adjusted
# See https://github.com/JuliaLang/julia/blob/81afdbc36b365fcbf3ae25b7451c6cb5798c0c3d/src/cgutils.cpp#L3795C1-L3801C121
if valty isa Union
T_int8 = LLVM.Int8Type()
if value_type(value) == T_int8
value = nuwsub!(builder, value, LLVM.ConstantInt(T_int8, 1))
end
end
return value
end

if Mode == API.DEM_ReverseModePrimal

# if in split mode and the return is a union marked duplicated, upgrade floating point like shadow returns into ref{ty} since otherwise use of the value will create problems.
Expand All @@ -3361,6 +3374,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if data[i] != -1
eval = extract_value!(builder, val, data[i])
end
eval = fixup_abi(i, eval)
ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)])
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
Expand Down Expand Up @@ -3394,6 +3408,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
@assert !(isghostty(combinedReturn) || Core.Compiler.isconstType(combinedReturn) )
@assert Core.Compiler.isconstType(ty)
eval = makeInstanceOf(ty)
eval = fixup_abi(i, eval)
ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)])
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
Expand Down Expand Up @@ -3421,14 +3436,14 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end
end
for returnNum in 0:(count_Sret-1)
eval = if count_llvm_Sret == 0
eval = fixup_abi(returnNum+1, if count_llvm_Sret == 0
makeInstanceOf(sret_types[returnNum+1])
elseif count_llvm_Sret == 1
val
else
@assert count_llvm_Sret > 1
extract_value!(builder, val, returnNum)
end
end)
ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)])
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
Expand All @@ -3440,11 +3455,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if Mode == API.DEM_ReverseModeCombined
if returnPrimal
if !isghostty(literal_rt)
eval = if !isghostty(actualRetType)
eval = fixup_abi(returnNum+1, if !isghostty(actualRetType)
extract_value!(builder, val, returnNum)
else
makeInstanceOf(sret_types[returnNum+1])
end
end)
store!(builder, eval, inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), length(elements(jltype))-1 )]))
returnNum+=1
end
Expand Down Expand Up @@ -4206,7 +4221,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};


mod, meta = GPUCompiler.codegen(:llvm, primal_job; optimize=false, toplevel=toplevel, cleanup=false, validate=false, parent_job=parent_job)

prepare_llvm(mod, primal_job, meta)

LLVM.ModulePassManager() do pm
Expand Down
39 changes: 39 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,45 @@ end
@test 2.0 Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), true)[1][1]
end

struct MyFlux
end

@testset "Union i8" begin
args = (
Val{(false, false, false)},
Val(1),
Val((true, true, true)),
Base.Val(NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}),
Base.getindex,
nothing,
((nothing,), MyFlux()),
((nothing,), MyFlux()),
1,
nothing
)

nt1 = Enzyme.Compiler.runtime_generic_augfwd(args...)
@test nt1[1] == (nothing,)
@test nt1[2] == (nothing,)

args2 = (
Val{(false, false, false)},
Val(1),
Val((true, true, true)),
Base.Val(NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}),
Base.getindex,
nothing,
((nothing,), MyFlux()),
((nothing,), MyFlux()),
2,
nothing
)

nt = Enzyme.Compiler.runtime_generic_augfwd(args2...)
@test nt[1] == MyFlux()
@test nt[2] == MyFlux()
end

@testset "Array push" begin

function pusher(x, y)
Expand Down

0 comments on commit 9579ec3

Please sign in to comment.