Skip to content

Commit

Permalink
Add nthfield (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Feb 3, 2023
1 parent d6a2a42 commit 5fce502
Showing 1 changed file with 64 additions and 1 deletion.
65 changes: 64 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,63 @@ function jlcall2_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef,
return nothing
end

function jl_nthfield_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid
if shadowR != C_NULL
orig = LLVM.Instruction(OrigCI)
origops = collect(operands(orig))
width = API.EnzymeGradientUtilsGetWidth(gutils)
if API.EnzymeGradientUtilsIsConstantValue(gutils, origops[1]) == 0
B = LLVM.Builder(B)

shadowin = LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, origops[1], B))
if width == 1
args = LLVM.Value[
shadowin
LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, origops[2]))
]
shadowres = LLVM.call!(B, LLVM.called_value(orig), args)
conv = LLVM.API.LLVMGetInstructionCallConv(orig)
LLVM.API.LLVMSetInstructionCallConv(shadowres, conv)
else
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig))))
for idx in 1:width
args = LLVM.Value[
extract_value!(B, shadowin, idx-1)
LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, origops[2]))
]
tmp = LLVM.call!(B, LLVM.called_value(orig), args)
conv = LLVM.API.LLVMGetInstructionCallConv(orig)
LLVM.API.LLVMSetInstructionCallConv(tmp, conv)
shadowres = insert_value!(B, shadowres, tmp, idx-1)
end
end
unsafe_store!(shadowR, shadowres.ref)
else
normal = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig))
if width == 1
shadowres = normal
else
shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(normal))))
for idx in 1:width
shadowres = insert_value!(B, shadowres, normal, idx-1)
end
end
unsafe_store!(shadowR, shadowres.ref)
end
end
return nothing
end
function jl_nthfield_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::Cvoid
jl_nthfield_fwd(B, OrigCI, gutils, normalR, shadowR)
end
function jl_nthfield_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid
orig = LLVM.Instruction(OrigCI)
if API.EnzymeGradientUtilsIsConstantValue(gutils, orig) == 0
emit_error(LLVM.Builder(B), orig, "Enzyme: not yet implemented in reverse mode, jl_nthfield")
end
return nothing
end

function common_invoke_fwd(offset, B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid
orig = LLVM.Instruction(OrigCI)

Expand Down Expand Up @@ -4606,6 +4663,12 @@ function __init__()
@cfunction(jl_getfield_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)),
@cfunction(jl_getfield_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})),
)
register_handler!(
("ijl_get_nth_field_checked","jl_get_nth_field_checked"),
@cfunction(jl_nthfield_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})),
@cfunction(jl_nthfield_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)),
@cfunction(jl_nthfield_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})),
)
register_handler!(
("jl_array_sizehint","ijl_array_sizehint"),
@cfunction(jl_array_sizehint_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})),
Expand Down Expand Up @@ -4768,7 +4831,7 @@ function annotate!(mod, mode)
end
end

for fname in ("jl_f_getfield",)
for fname in ("jl_f_getfield","ijl_f_getfield","jl_get_nth_field_checked","ijl_get_nth_field_checked")
if haskey(fns, fname)
fn = fns[fname]
push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0; ctx))
Expand Down

0 comments on commit 5fce502

Please sign in to comment.