Skip to content

Commit

Permalink
Fix 1.9 and GC
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 4, 2023
1 parent 0a5a8cc commit e96be3f
Showing 1 changed file with 52 additions and 29 deletions.
81 changes: 52 additions & 29 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2856,27 +2856,38 @@ function enzyme_custom_setup_args(B, orig, gutils, mi)
# TODO type analysis deduce if duplicated vs active
if activep == API.DFT_CONSTANT
Ty = Const{arg.typ}

llty = convert(LLVMType, Ty; ctx)
al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(llvmtype(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, 11))

arval = LLVM.UndefValue(llty)
arval = insert_value!(B, arval, val, 0)

al = alloca!(alloctx, llvmtype(arval))
store!(B, arval, al)
ptr = gep!(B, al, [LLVM.ConstantInt(LLVM.IntType(64; ctx), 0), LLVM.ConstantInt(LLVM.IntType(32; ctx), 0)])
store!(B, val, ptr)

if any_jltypes(llty)
vals = LLVM.Value[al0, val]
emit_writebarrier!(B, vals)
end

push!(args, al)

push!(activity, Ty)

elseif activep == API.DFT_OUT_DIFF
Ty = Active{arg.typ}
llty = convert(LLVMType, Ty; ctx)

arval = LLVM.UndefValue(llty)
arval = insert_value!(B, arval, val, 0)

al = alloca!(alloctx, llvmtype(arval))
store!(B, arval, al)
al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(llvmtype(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, 11))

ptr = gep!(B, al, [LLVM.ConstantInt(LLVM.IntType(64; ctx), 0), LLVM.ConstantInt(LLVM.IntType(32; ctx), 0)])
store!(B, val, ptr)

if any_jltypes(llty)
vals = LLVM.Value[al0, val]
emit_writebarrier!(B, vals)
end

push!(args, al)

push!(activity, Ty)
Expand All @@ -2900,13 +2911,21 @@ function enzyme_custom_setup_args(B, orig, gutils, mi)
end

llty = convert(LLVMType, Ty; ctx)
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(llvmtype(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, 11))

arval = LLVM.UndefValue(llty)
arval = insert_value!(B, arval, val, 0)
arval = insert_value!(B, arval, ival, 1)
ptr = gep!(B, al, [LLVM.ConstantInt(LLVM.IntType(64; ctx), 0), LLVM.ConstantInt(LLVM.IntType(32; ctx), 0)])
store!(B, val, ptr)

iptr = gep!(B, al, [LLVM.ConstantInt(LLVM.IntType(64; ctx), 0), LLVM.ConstantInt(LLVM.IntType(32; ctx), 1)])
store!(B, ival, iptr)

al = alloca!(alloctx, llvmtype(arval))
store!(B, arval, al)
if any_jltypes(llty)
vals = LLVM.Value[al0, val, ival]
emit_writebarrier!(B, vals)
end

push!(args, al)
push!(activity, Ty)
end
Expand Down Expand Up @@ -3010,12 +3029,8 @@ function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu
if llvmtype(args[i]) == party
continue
end
if LLVM.addrspace(party) != 0
args[i] = addrspacecast!(B, args[i], party)
else
GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf
return
end
GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf
return
end

res = LLVM.call!(B, llvmf, args)
Expand All @@ -3028,6 +3043,12 @@ function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu
end

if sret !== nothing
attr = if LLVM.version().major >= 12
TypeAttribute("sret", eltype(llvmtype(parameters(llvmf)[1])); ctx)
else
EnumAttribute("sret"; ctx)
end
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1), attr)
res = load!(B, sret)
end

Expand Down Expand Up @@ -3199,12 +3220,8 @@ function enzyme_custom_common_rev(forward::Bool, B::LLVM.API.LLVMBuilderRef, Ori
if llvmtype(args[i]) == party
continue
end
if isa(party, LLVM.PointerType) && LLVM.addrspace(party) != 0
args[i] = addrspacecast!(B, args[i], party)
else
GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf, augprimal_TT, rev_TT
return tapeV
end
GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf, augprimal_TT, rev_TT
return tapeV
end

res = LLVM.call!(B, llvmf, args)
Expand All @@ -3217,6 +3234,12 @@ function enzyme_custom_common_rev(forward::Bool, B::LLVM.API.LLVMBuilderRef, Ori
end

if sret !== nothing
attr = if LLVM.version().major >= 12
TypeAttribute("sret", eltype(llvmtype(parameters(llvmf)[1])); ctx)
else
EnumAttribute("sret"; ctx)
end
LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(1), attr)
res = load!(B, sret)
end

Expand Down

0 comments on commit e96be3f

Please sign in to comment.