Skip to content

Commit

Permalink
extract MI
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Mar 22, 2022
1 parent 77e9f6f commit e78ce1a
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1682,15 +1682,32 @@ function enzyme_custom_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV
end

orig = LLVM.Instruction(OrigCI)
ctx = LLVM.context(orig)

# TODO: don't inject the code multiple times for multiple calls

# 1) extract out the MI from metadata
mi = ...

active = API.EnzymeGradientUtilsIsConstantValue(gutils, orig) == 0
# 1) extract out the MI from attributes
llvmfn = LLVM.called_value(orig)
mi = nothing
for fattr in collect(function_attributes(llvmfn))
if isa(fattr, LLVM.StringAttribute)
if kind(fattr) == "enzymejl_mi"
ptr = reinterpret(Ptr{Cvoid}, parse(Int, LLVM.value(fattr)))
mi = Base.unsafe_pointer_to_objref(ptr)
break
end
end
end
builder = LLVM.Builder(B)

if mi === nothing
emit_error(builder, "Enzyme: Custom augmented forward handler, could not find MI")
end
emit_error(builder, "Enzyme: Custom augmented forward handler, not yet implemented")
return nothing

# TODO: don't inject the code multiple times for multiple calls

# 2) Create activity, and annotate function spec
active = API.EnzymeGradientUtilsIsConstantValue(gutils, orig) == 0
ops = collect(operands(orig))[1:end-1]

args = LLVM.Value[]
Expand All @@ -1709,22 +1726,24 @@ function enzyme_custom_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV
tt = annotate_tuple_type(mi.specTypes, activity)
funcspec = FunctionSpec(EnzymeRules.augmented_forward, tt, #=kernel=# false, #=name=# nothing)

# TODO: GPU support
# 2) Use the MI to create the correct augmented fwd/reverse
# 3) Use the MI to create the correct augmented fwd/reverse
# TODO:
# - GPU support
# - When OrcV2 only use a MaterializationUnit to avoid mutation of the module here

target = GPUCompiler.NativeCompilerTarget()
params = Compiler.PrimalCompilerParams()
job = CompilerJob(target, funcspec, params)

otherMod, meta = GPUCompiler.codegen(:llvm, job, optimize=false, validate=false)
entry = name(meta.entry)

# 3) Link the corresponding module
builder = LLVM.Builder(B)
# 4) Link the corresponding module
bb = LLVM.position(builder)
mod = LLVM.parent(LLVM.parent(bb))
LLVM.link!(mod, otherMod)

# 4) Call the function
# 5) Call the function
entry = functions(mod)[entry]

emit_error(builder, "Enzyme: Not yet implemented custom augmented forward handler")
Expand Down

0 comments on commit e78ce1a

Please sign in to comment.