From 5fce502cc31fa6e2da6d0b5e86dc79f03a816b5c Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 3 Feb 2023 13:46:44 -0500 Subject: [PATCH] Add nthfield (#594) --- src/compiler.jl | 65 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 128eb9d075..e46c1c23a4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1962,6 +1962,63 @@ function jlcall2_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, return nothing end +function jl_nthfield_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + if shadowR != C_NULL + orig = LLVM.Instruction(OrigCI) + origops = collect(operands(orig)) + width = API.EnzymeGradientUtilsGetWidth(gutils) + if API.EnzymeGradientUtilsIsConstantValue(gutils, origops[1]) == 0 + B = LLVM.Builder(B) + + shadowin = LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, origops[1], B)) + if width == 1 + args = LLVM.Value[ + shadowin + LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, origops[2])) + ] + shadowres = LLVM.call!(B, LLVM.called_value(orig), args) + conv = LLVM.API.LLVMGetInstructionCallConv(orig) + LLVM.API.LLVMSetInstructionCallConv(shadowres, conv) + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig)))) + for idx in 1:width + args = LLVM.Value[ + extract_value!(B, shadowin, idx-1) + LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, origops[2])) + ] + tmp = LLVM.call!(B, LLVM.called_value(orig), args) + conv = LLVM.API.LLVMGetInstructionCallConv(orig) + LLVM.API.LLVMSetInstructionCallConv(tmp, conv) + shadowres = insert_value!(B, shadowres, tmp, idx-1) + end + end + unsafe_store!(shadowR, shadowres.ref) + else + normal = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig)) + if width == 1 + shadowres = normal + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(normal)))) + for idx in 1:width + shadowres = insert_value!(B, shadowres, normal, idx-1) + end + end + unsafe_store!(shadowR, shadowres.ref) + end + end + return nothing +end +function jl_nthfield_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + jl_nthfield_fwd(B, OrigCI, gutils, normalR, shadowR) +end +function jl_nthfield_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid + orig = LLVM.Instruction(OrigCI) + if API.EnzymeGradientUtilsIsConstantValue(gutils, orig) == 0 + emit_error(LLVM.Builder(B), orig, "Enzyme: not yet implemented in reverse mode, jl_nthfield") + end + return nothing +end + function common_invoke_fwd(offset, B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid orig = LLVM.Instruction(OrigCI) @@ -4606,6 +4663,12 @@ function __init__() @cfunction(jl_getfield_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)), @cfunction(jl_getfield_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), ) + register_handler!( + ("ijl_get_nth_field_checked","jl_get_nth_field_checked"), + @cfunction(jl_nthfield_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), + @cfunction(jl_nthfield_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)), + @cfunction(jl_nthfield_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), + ) register_handler!( ("jl_array_sizehint","ijl_array_sizehint"), @cfunction(jl_array_sizehint_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), @@ -4768,7 +4831,7 @@ function annotate!(mod, mode) end end - for fname in ("jl_f_getfield",) + for fname in ("jl_f_getfield","ijl_f_getfield","jl_get_nth_field_checked","ijl_get_nth_field_checked") if haskey(fns, fname) fn = fns[fname] push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0; ctx))