Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle closures #66

Merged
merged 1 commit into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ prepare_cc(arg::Annotation, args...) = (arg.val, prepare_cc(args...)...)
ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(true))
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(ptr)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr)
thunk(args′...)
end

Expand All @@ -72,7 +72,7 @@ end
ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(false))
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(ptr)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr)
thunk(args′...)
end

Expand Down Expand Up @@ -104,7 +104,7 @@ for op in (asin,tanh)
for (T, llvm_t, suffix) in ((Float32, "float", "f"), (Float64, "double", ""))
mod = """
declare $llvm_t @$(nameof(op))$suffix($llvm_t)

define $llvm_t @entry($llvm_t) #0 {
%val = call $llvm_t @$op$suffix($llvm_t %0)
ret $llvm_t %val
Expand Down
83 changes: 57 additions & 26 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
args_known_values = API.IntList[]

ctx = LLVM.context(mod)
if !GPUCompiler.isghosttype(typeof(adjoint.f)) && !Core.Compiler.isconstType(typeof(adjoint.f))
push!(args_activity, API.DFT_CONSTANT)
typeTree = typetree(typeof(adjoint.f), ctx, dl)
push!(args_typeInfo, typeTree)
if split
push!(uncacheable_args, true)
else
push!(uncacheable_args, false)
end
push!(args_known_values, API.IntList())
end

for T in tt
source_typ = eltype(T)
Expand All @@ -159,7 +170,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
continue
end
isboxed = GPUCompiler.deserves_argbox(source_typ)

if T <: Const
push!(args_activity, API.DFT_CONSTANT)
elseif T <: Active
Expand All @@ -173,11 +184,11 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
push!(args_activity, API.DFT_DUP_ARG)
elseif T <: DuplicatedNoNeed
push!(args_activity, API.DFT_DUP_NONEED)
else
else
@assert("illegal annotation type")
end
T = source_typ
if isboxed
if isboxed
T = Ptr{T}
end
typeTree = typetree(T, ctx, dl)
Expand All @@ -197,7 +208,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
# If requested, the shadow return value of the function
# For each active (non duplicated) argument
# The adjoint of that argument
if rt <: Integer
if rt <: Integer || rt <: DataType
retType = API.DFT_CONSTANT
elseif rt <: AbstractFloat
retType = API.DFT_OUT_DIFF
Expand All @@ -209,7 +220,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
error("What even is $rt")
end

TA = TypeAnalysis(triple(mod))
TA = TypeAnalysis(triple(mod))
logic = Logic()

if GPUCompiler.isghosttype(rt)|| Core.Compiler.isconstType(rt)
Expand Down Expand Up @@ -402,7 +413,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
else
target_machine = GPUCompiler.llvm_machine(primal_job.target)
end

parallel = false
process_module = false
if parent_job !== nothing
Expand Down Expand Up @@ -469,17 +480,18 @@ end

##
# Thunk
##
##

struct CombinedAdjointThunk{f, RT, TT}
struct CombinedAdjointThunk{F, RT, TT}
fn::F
# primal::Ptr{Cvoid}
adjoint::Ptr{Cvoid}
end

@inline (thunk::CombinedAdjointThunk{F, RT, TT})(args...) where {F, RT, TT} =
enzyme_call(thunk.adjoint, TT, RT, args...)
enzyme_call(thunk.adjoint, thunk.fn, TT, RT, args...)

@generated function enzyme_call(f::Ptr{Cvoid}, tt::Type{T}, rt::Type{RT}, args::Vararg{Any, N}) where {T, RT, N}
@generated function enzyme_call(fptr::Ptr{Cvoid}, f::F, tt::Type{T}, rt::Type{RT}, args::Vararg{Any, N}) where {F, T, RT, N}
argtt = tt.parameters[1]
rettype = rt.parameters[1]
argtypes = DataType[argtt.parameters...]
Expand Down Expand Up @@ -508,6 +520,20 @@ end
# By ref values we create and need to preserve
ccexprs = Union{Expr, Symbol}[] # The expressions passed to the `llvmcall`

if !GPUCompiler.isghosttype(F) && !Core.Compiler.isconstType(F)
isboxed = GPUCompiler.deserves_argbox(F)
llvmT = isboxed ? T_prjlvalue : convert(LLVMType, F, ctx)
argexpr = :(f)
if isboxed
push!(types, Any)
else
push!(types, F)
end

push!(ccexprs, argexpr)
push!(T_wrapperargs, llvmT)
end

for (i, T) in enumerate(argtypes)
source_typ = eltype(T)
if GPUCompiler.isghosttype(source_typ) || Core.Compiler.isconstType(source_typ)
Expand All @@ -528,7 +554,7 @@ end

push!(ccexprs, argexpr)
push!(T_wrapperargs, llvmT)

T <: Const && continue

if T <: Active
Expand Down Expand Up @@ -591,29 +617,34 @@ end
realparms = LLVM.Value[]
i = target+1

if !isempty(T_JuliaSRet)
if !isempty(T_JuliaSRet)
sret = inttoptr!(builder, params[1], LLVM.PointerType(LLVM.StructType(T_JuliaSRet)))
end

activeNum = 0

if !GPUCompiler.isghosttype(F) && !Core.Compiler.isconstType(F)
push!(realparms, params[i])
i+=1
end

for T in argtypes
T′ = eltype(T)

if GPUCompiler.isghosttype(T′) || Core.Compiler.isconstType(T′)
continue
end
isboxed = GPUCompiler.deserves_argbox(T′)
push!(realparms, params[i])
i+=1
if T <: Const
elseif T <: Active
isboxed = GPUCompiler.deserves_argbox(T′)
if isboxed
ptr = gep!(builder, sret, [LLVM.ConstantInt(LLVM.IntType(64, ctx), 0), LLVM.ConstantInt(LLVM.IntType(32, ctx), activeNum)])
cst = pointercast!(builder, ptr, ptr8)
push!(realparms, ptr)

cparms = LLVM.Value[cst,
cparms = LLVM.Value[cst,
LLVM.ConstantInt(LLVM.IntType(8, ctx), 0),
LLVM.ConstantInt(LLVM.IntType(64, ctx), LLVM.storage_size(dl, Base.eltype(LLVM.llvmtype(ptr)) )),
LLVM.ConstantInt(LLVM.IntType(1, ctx), 0)]
Expand All @@ -626,8 +657,8 @@ end
end
end

# Primal Return type
if i <= size(params, 1)
# Primal Differential Return type
if rettype <: AbstractFloat || rettype <: Complex{<:AbstractFloat}
push!(realparms, params[i])
end

Expand All @@ -640,7 +671,7 @@ end

ptr = inttoptr!(builder, params[target], LLVM.PointerType(ft))
val = call!(builder, ptr, realparms)
if !isempty(T_JuliaSRet)
if !isempty(T_JuliaSRet)
activeNum = 0
returnNum = 0
for T in argtypes
Expand All @@ -662,25 +693,25 @@ end
ir = string(mod)
fn = LLVM.name(llvm_f)

if !isempty(T_JuliaSRet)
if !isempty(T_JuliaSRet)
quote
Base.@_inline_meta
sret = Ref{$(Tuple{sret_types...})}()
GC.@preserve sret begin
ptr = Base.unsafe_convert(Ptr{$(Tuple{sret_types...})}, sret)
ptr = Base.unsafe_convert(Ptr{Cvoid}, ptr)
tptr = Base.unsafe_convert(Ptr{$(Tuple{sret_types...})}, sret)
tptr = Base.unsafe_convert(Ptr{Cvoid}, tptr)
Base.llvmcall(($ir,$fn), Cvoid,
$(Tuple{Ptr{Cvoid}, Ptr{Cvoid}, types...}),
ptr, f, $(ccexprs...))
tptr, fptr, $(ccexprs...))
end
sret[]
end
else
else
quote
Base.@_inline_meta
Base.llvmcall(($ir,$fn), Cvoid,
$(Tuple{Ptr{Cvoid}, types...}),
f, $(ccexprs...))
fptr, $(ccexprs...))
end
end
end
Expand Down Expand Up @@ -728,7 +759,7 @@ function _link(job, (mod, adjoint_name, primal_name))
adjoint = params.adjoint
split = params.split

primal = job.source
primal = job.source
rt = Core.Compiler.return_type(primal.f, primal.tt)

# Now invoke the JIT
Expand All @@ -751,7 +782,7 @@ function _link(job, (mod, adjoint_name, primal_name))
end

@assert primal_name === nothing
return CombinedAdjointThunk{typeof(adjoint.f), rt, adjoint.tt}(#=primal_ptr,=# adjoint_ptr)
return CombinedAdjointThunk{typeof(adjoint.f), rt, adjoint.tt}(adjoint.f, #=primal_ptr,=# adjoint_ptr)
end

# actual compilation
Expand Down Expand Up @@ -859,4 +890,4 @@ end
include("compiler/reflection.jl")
include("compiler/validation.jl")

end
end
11 changes: 7 additions & 4 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
data = open(flib, "r") do io
lib = readmeta(io)
sections = Sections(lib)
if !(".llvmbc" in sections)
return nothing
end
llvmbc = read(findfirst(sections, ".llvmbc"))
return llvmbc
end
Expand Down Expand Up @@ -247,7 +250,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
end
end
end

b = Builder(ctx)

position!(b, inst)
Expand All @@ -271,7 +274,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
end
end
end

b = Builder(ctx)
position!(b, inst)
replace_uses!(inst, LLVM.inttoptr!(b, replaceWith, llvmtype(inst)))
Expand Down Expand Up @@ -327,8 +330,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
if ptr == cglobal(:malloc)
fn = "malloc"
end
if length(fn) > 1 && fromC

if length(fn) > 1 && fromC
mod = LLVM.parent(LLVM.parent(LLVM.parent(inst)))
lfn = LLVM.API.LLVMGetNamedFunction(mod, fn)
if lfn == C_NULL
Expand Down
18 changes: 11 additions & 7 deletions src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ LLVM.dispose(tt::TypeTree) = API.EnzymeFreeTypeTree(tt)

TypeTree() = TypeTree(API.EnzymeNewTypeTree())
TypeTree(CT, ctx) = TypeTree(API.EnzymeNewTypeTreeCT(CT, ctx))
function TypeTree(CT, idx, ctx)
function TypeTree(CT, idx, ctx)
tt = TypeTree(CT, ctx)
only!(tt, idx)
return tt
Expand Down Expand Up @@ -73,6 +73,10 @@ function typetree(::Type{Float64}, ctx, dl)
return TypeTree(API.DT_Double, -1, ctx)
end

function typetree(::Type{<:DataType}, ctx, dl)
return TypeTree()
end

function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl) where T
tt = typetree(T, ctx, dl)
merge!(tt, TypeTree(API.DT_Pointer, ctx))
Expand Down Expand Up @@ -123,9 +127,9 @@ function typetree(@nospecialize(T), ctx, dl)
if subT.isinlinealloc
shift!(subtree, dl, 0, sizeof(subT), offset)
else
merge!(subtree, TypeTree(API.DT_Pointer, ctx))
merge!(subtree, TypeTree(API.DT_Pointer, ctx))
only!(subtree, offset)
end
end

merge!(tt, subtree)
end
Expand All @@ -139,14 +143,14 @@ struct FnTypeInfo
end
Base.cconvert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) = fnti
function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo)
args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values))
rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT))
args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values))
rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT))

tts = API.CTypeTreeRef[]
for tt in fnti.argTTs
raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt))
raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt))
push!(tts, raw_tt)
end
argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts))
argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts))
return API.CFnTypeInfo(argTTs, rTT, args_kv)
end
Loading