Skip to content

Commit

Permalink
Recursive codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Mar 22, 2022
1 parent 847ed40 commit 77e9f6f
Showing 1 changed file with 91 additions and 5 deletions.
96 changes: 91 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,80 @@ function wait_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gut
return nothing
end


function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid
normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing
if shadowR != C_NULL && normal !== nothing
unsafe_store!(shadowR, normal.ref)
end

emit_error(LLVM.Builder(B), "Enzyme: Not yet implemented custom forward handler")
return nothing
end

function enzyme_custom_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::Cvoid
normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing
if shadowR != C_NULL && normal !== nothing
unsafe_store!(shadowR, normal.ref)
end

orig = LLVM.Instruction(OrigCI)

# TODO: don't inject the code multiple times for multiple calls

# 1) extract out the MI from metadata
mi = ...

active = API.EnzymeGradientUtilsIsConstantValue(gutils, orig) == 0


ops = collect(operands(orig))[1:end-1]

args = LLVM.Value[]

for op in ops
val = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, op))
push!(args, val)

active = API.EnzymeGradientUtilsIsConstantValue(gutils, op) == 0
# TODO type analysis deduce if duplicated vs active
if active
push!(args, LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, op, B)))
end
end

tt = annotate_tuple_type(mi.specTypes, activity)
funcspec = FunctionSpec(EnzymeRules.augmented_forward, tt, #=kernel=# false, #=name=# nothing)

# TODO: GPU support
# 2) Use the MI to create the correct augmented fwd/reverse
target = GPUCompiler.NativeCompilerTarget()
params = Compiler.PrimalCompilerParams()
job = CompilerJob(target, funcspec, params)

otherMod, meta = GPUCompiler.codegen(:llvm, job, optimize=false, validate=false)
entry = name(meta.entry)

# 3) Link the corresponding module
builder = LLVM.Builder(B)
bb = LLVM.position(builder)
mod = LLVM.parent(LLVM.parent(bb))
LLVM.link!(mod, otherMod)

# 4) Call the function
entry = functions(mod)[entry]

emit_error(builder, "Enzyme: Not yet implemented custom augmented forward handler")

return nothing
end


function enzyme_custom_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid
emit_error(LLVM.Builder(B), "Enzyme: Not yet implemented custom reverse handler")
return nothing
end

function arraycopy_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid

orig = LLVM.Instruction(OrigCI)
Expand Down Expand Up @@ -2216,6 +2290,12 @@ function __init__()
@cfunction(enq_work_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)),
@cfunction(enq_work_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}))
)
register_handler!(
("enzyme_custom",),
@cfunction(enzyme_custom_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})),
@cfunction(enzyme_custom_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)),
@cfunction(enzyme_custom_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}))
)
register_handler!(
("jl_wait",),
@cfunction(wait_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})),
Expand Down Expand Up @@ -3233,12 +3313,11 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
Base.FastMath.tanh_fast => (:tanh, 1)
)
actualRetType = nothing
customDerivativeNames = String[]
for (mi, k) in meta.compiled
k_name = GPUCompiler.safe_name(k.specfunc)
if has_rule(mi.specTypes)
@warn "Rule support not yet implemented"
continue
elseif !haskey(functions(mod), k_name)
has_custom_rule = has_rule(mi.specTypes)
if !(haskey(functions(mod), k_name) || has_custom_rule)
continue
end

Expand Down Expand Up @@ -3271,7 +3350,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end

sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals
if func == Base.println || func == Base.print || func == Base.show ||
if has_custom_rule
attributes = function_attributes(llvmfn)
push!(attributes, EnumAttribute("noinline", 0; ctx))
push!(attributes, StringAttribute("enzymejl_mi", string(convert(Int, pointer_from_objref(mi))); ctx))
push!(attributes, StringAttribute("enzyme_math", "enzyme_custom"; ctx))
push!(custom, k.specfunc)
continue
else if func == Base.println || func == Base.print || func == Base.show ||
func == Base.flush || func == Base.string || func == Base.print_to_string
handleCustom("enz_noop", [StringAttribute("enzyme_inactive"; ctx)])
continue
Expand Down

0 comments on commit 77e9f6f

Please sign in to comment.