diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 9819246f42..4082d7c26c 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -62,7 +62,7 @@ prepare_cc(arg::Annotation, args...) = (arg.val, prepare_cc(args...)...) ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(true)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} rt = Core.Compiler.return_type(f, tt) - thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(ptr) + thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr) thunk(args′...) end @@ -72,7 +72,7 @@ end ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(false)) tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} rt = Core.Compiler.return_type(f, tt) - thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(ptr) + thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr) thunk(args′...) end @@ -104,7 +104,7 @@ for op in (asin,tanh) for (T, llvm_t, suffix) in ((Float32, "float", "f"), (Float64, "double", "")) mod = """ declare $llvm_t @$(nameof(op))$suffix($llvm_t) - + define $llvm_t @entry($llvm_t) #0 { %val = call $llvm_t @$op$suffix($llvm_t %0) ret $llvm_t %val diff --git a/src/compiler.jl b/src/compiler.jl index 263ff778a2..a630101284 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -149,6 +149,17 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel) args_known_values = API.IntList[] ctx = LLVM.context(mod) + if !GPUCompiler.isghosttype(typeof(adjoint.f)) && !Core.Compiler.isconstType(typeof(adjoint.f)) + push!(args_activity, API.DFT_CONSTANT) + typeTree = typetree(typeof(adjoint.f), ctx, dl) + push!(args_typeInfo, typeTree) + if split + push!(uncacheable_args, true) + else + push!(uncacheable_args, false) + end + push!(args_known_values, API.IntList()) + end for T in tt source_typ = eltype(T) @@ -159,7 +170,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel) continue end isboxed = GPUCompiler.deserves_argbox(source_typ) - + if T <: Const push!(args_activity, API.DFT_CONSTANT) elseif T <: Active @@ -173,11 +184,11 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel) push!(args_activity, API.DFT_DUP_ARG) elseif T <: DuplicatedNoNeed push!(args_activity, API.DFT_DUP_NONEED) - else + else @assert("illegal annotation type") end T = source_typ - if isboxed + if isboxed T = Ptr{T} end typeTree = typetree(T, ctx, dl) @@ -197,7 +208,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel) # If requested, the shadow return value of the function # For each active (non duplicated) argument # The adjoint of that argument - if rt <: Integer + if rt <: Integer || rt <: DataType retType = API.DFT_CONSTANT elseif rt <: AbstractFloat retType = API.DFT_OUT_DIFF @@ -209,7 +220,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel) error("What even is $rt") end - TA = TypeAnalysis(triple(mod)) + TA = TypeAnalysis(triple(mod)) logic = Logic() if GPUCompiler.isghosttype(rt)|| Core.Compiler.isconstType(rt) @@ -402,7 +413,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; else target_machine = GPUCompiler.llvm_machine(primal_job.target) end - + parallel = false process_module = false if parent_job !== nothing @@ -469,17 +480,18 @@ end ## # Thunk -## +## -struct CombinedAdjointThunk{f, RT, TT} +struct CombinedAdjointThunk{F, RT, TT} + fn::F # primal::Ptr{Cvoid} adjoint::Ptr{Cvoid} end @inline (thunk::CombinedAdjointThunk{F, RT, TT})(args...) where {F, RT, TT} = - enzyme_call(thunk.adjoint, TT, RT, args...) + enzyme_call(thunk.adjoint, thunk.fn, TT, RT, args...) -@generated function enzyme_call(f::Ptr{Cvoid}, tt::Type{T}, rt::Type{RT}, args::Vararg{Any, N}) where {T, RT, N} +@generated function enzyme_call(fptr::Ptr{Cvoid}, f::F, tt::Type{T}, rt::Type{RT}, args::Vararg{Any, N}) where {F, T, RT, N} argtt = tt.parameters[1] rettype = rt.parameters[1] argtypes = DataType[argtt.parameters...] @@ -508,6 +520,20 @@ end # By ref values we create and need to preserve ccexprs = Union{Expr, Symbol}[] # The expressions passed to the `llvmcall` + if !GPUCompiler.isghosttype(F) && !Core.Compiler.isconstType(F) + isboxed = GPUCompiler.deserves_argbox(F) + llvmT = isboxed ? T_prjlvalue : convert(LLVMType, F, ctx) + argexpr = :(f) + if isboxed + push!(types, Any) + else + push!(types, F) + end + + push!(ccexprs, argexpr) + push!(T_wrapperargs, llvmT) + end + for (i, T) in enumerate(argtypes) source_typ = eltype(T) if GPUCompiler.isghosttype(source_typ) || Core.Compiler.isconstType(source_typ) @@ -528,7 +554,7 @@ end push!(ccexprs, argexpr) push!(T_wrapperargs, llvmT) - + T <: Const && continue if T <: Active @@ -591,29 +617,34 @@ end realparms = LLVM.Value[] i = target+1 - if !isempty(T_JuliaSRet) + if !isempty(T_JuliaSRet) sret = inttoptr!(builder, params[1], LLVM.PointerType(LLVM.StructType(T_JuliaSRet))) end activeNum = 0 + if !GPUCompiler.isghosttype(F) && !Core.Compiler.isconstType(F) + push!(realparms, params[i]) + i+=1 + end + for T in argtypes T′ = eltype(T) if GPUCompiler.isghosttype(T′) || Core.Compiler.isconstType(T′) continue end - isboxed = GPUCompiler.deserves_argbox(T′) push!(realparms, params[i]) i+=1 if T <: Const elseif T <: Active + isboxed = GPUCompiler.deserves_argbox(T′) if isboxed ptr = gep!(builder, sret, [LLVM.ConstantInt(LLVM.IntType(64, ctx), 0), LLVM.ConstantInt(LLVM.IntType(32, ctx), activeNum)]) cst = pointercast!(builder, ptr, ptr8) push!(realparms, ptr) - cparms = LLVM.Value[cst, + cparms = LLVM.Value[cst, LLVM.ConstantInt(LLVM.IntType(8, ctx), 0), LLVM.ConstantInt(LLVM.IntType(64, ctx), LLVM.storage_size(dl, Base.eltype(LLVM.llvmtype(ptr)) )), LLVM.ConstantInt(LLVM.IntType(1, ctx), 0)] @@ -626,8 +657,8 @@ end end end - # Primal Return type - if i <= size(params, 1) + # Primal Differential Return type + if rettype <: AbstractFloat || rettype <: Complex{<:AbstractFloat} push!(realparms, params[i]) end @@ -640,7 +671,7 @@ end ptr = inttoptr!(builder, params[target], LLVM.PointerType(ft)) val = call!(builder, ptr, realparms) - if !isempty(T_JuliaSRet) + if !isempty(T_JuliaSRet) activeNum = 0 returnNum = 0 for T in argtypes @@ -662,25 +693,25 @@ end ir = string(mod) fn = LLVM.name(llvm_f) - if !isempty(T_JuliaSRet) + if !isempty(T_JuliaSRet) quote Base.@_inline_meta sret = Ref{$(Tuple{sret_types...})}() GC.@preserve sret begin - ptr = Base.unsafe_convert(Ptr{$(Tuple{sret_types...})}, sret) - ptr = Base.unsafe_convert(Ptr{Cvoid}, ptr) + tptr = Base.unsafe_convert(Ptr{$(Tuple{sret_types...})}, sret) + tptr = Base.unsafe_convert(Ptr{Cvoid}, tptr) Base.llvmcall(($ir,$fn), Cvoid, $(Tuple{Ptr{Cvoid}, Ptr{Cvoid}, types...}), - ptr, f, $(ccexprs...)) + tptr, fptr, $(ccexprs...)) end sret[] end - else + else quote Base.@_inline_meta Base.llvmcall(($ir,$fn), Cvoid, $(Tuple{Ptr{Cvoid}, types...}), - f, $(ccexprs...)) + fptr, $(ccexprs...)) end end end @@ -728,7 +759,7 @@ function _link(job, (mod, adjoint_name, primal_name)) adjoint = params.adjoint split = params.split - primal = job.source + primal = job.source rt = Core.Compiler.return_type(primal.f, primal.tt) # Now invoke the JIT @@ -751,7 +782,7 @@ function _link(job, (mod, adjoint_name, primal_name)) end @assert primal_name === nothing - return CombinedAdjointThunk{typeof(adjoint.f), rt, adjoint.tt}(#=primal_ptr,=# adjoint_ptr) + return CombinedAdjointThunk{typeof(adjoint.f), rt, adjoint.tt}(adjoint.f, #=primal_ptr,=# adjoint_ptr) end # actual compilation @@ -859,4 +890,4 @@ end include("compiler/reflection.jl") include("compiler/validation.jl") -end \ No newline at end of file +end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 202ff8bba8..4356056f1a 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -201,6 +201,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst) data = open(flib, "r") do io lib = readmeta(io) sections = Sections(lib) + if !(".llvmbc" in sections) + return nothing + end llvmbc = read(findfirst(sections, ".llvmbc")) return llvmbc end @@ -247,7 +250,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst) end end end - + b = Builder(ctx) position!(b, inst) @@ -271,7 +274,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst) end end end - + b = Builder(ctx) position!(b, inst) replace_uses!(inst, LLVM.inttoptr!(b, replaceWith, llvmtype(inst))) @@ -327,8 +330,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst) if ptr == cglobal(:malloc) fn = "malloc" end - - if length(fn) > 1 && fromC + + if length(fn) > 1 && fromC mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) lfn = LLVM.API.LLVMGetNamedFunction(mod, fn) if lfn == C_NULL diff --git a/src/typetree.jl b/src/typetree.jl index a73fb4aa3c..f9f5eefb41 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -11,7 +11,7 @@ LLVM.dispose(tt::TypeTree) = API.EnzymeFreeTypeTree(tt) TypeTree() = TypeTree(API.EnzymeNewTypeTree()) TypeTree(CT, ctx) = TypeTree(API.EnzymeNewTypeTreeCT(CT, ctx)) -function TypeTree(CT, idx, ctx) +function TypeTree(CT, idx, ctx) tt = TypeTree(CT, ctx) only!(tt, idx) return tt @@ -73,6 +73,10 @@ function typetree(::Type{Float64}, ctx, dl) return TypeTree(API.DT_Double, -1, ctx) end +function typetree(::Type{<:DataType}, ctx, dl) + return TypeTree() +end + function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl) where T tt = typetree(T, ctx, dl) merge!(tt, TypeTree(API.DT_Pointer, ctx)) @@ -123,9 +127,9 @@ function typetree(@nospecialize(T), ctx, dl) if subT.isinlinealloc shift!(subtree, dl, 0, sizeof(subT), offset) else - merge!(subtree, TypeTree(API.DT_Pointer, ctx)) + merge!(subtree, TypeTree(API.DT_Pointer, ctx)) only!(subtree, offset) - end + end merge!(tt, subtree) end @@ -139,14 +143,14 @@ struct FnTypeInfo end Base.cconvert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) = fnti function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) - args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values)) - rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT)) + args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values)) + rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT)) tts = API.CTypeTreeRef[] for tt in fnti.argTTs - raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt)) + raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt)) push!(tts, raw_tt) end - argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts)) + argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts)) return API.CFnTypeInfo(argTTs, rTT, args_kv) end diff --git a/test/abi.jl b/test/abi.jl index 02cd987da8..64375071a4 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -3,7 +3,7 @@ using Test @testset "ABI & Calling convention" begin - f(x) = x + f(x) = x # GhostType -> Nothing res = autodiff(f, Const(nothing)) @@ -93,3 +93,38 @@ using Test # returns: sret, const/ghost, !deserve_retbox end + + +@testset "Callable ABI" begin + function method(f, x) + return f(x) + end + + struct AFoo + x::Float64 + end + + function (f::AFoo)(x::Float64) + return f.x * x + end + + @test Enzyme.autodiff_no_cassette(method, AFoo(2.0), Active(3.0))[1]≈ 2.0 + @test Enzyme.autodiff(method, AFoo(2.0), Active(3.0))[1]≈ 2.0 + + @test Enzyme.autodiff_no_cassette(AFoo(2.0), Active(3.0))[1]≈ 2.0 + @test Enzyme.autodiff(AFoo(2.0), Active(3.0))[1]≈ 2.0 + + + struct ABar + end + + function (f::ABar)(x::Float64) + return 2.0 * x + end + + @test Enzyme.autodiff_no_cassette(method, ABar(), Active(3.0))[1]≈ 2.0 + @test Enzyme.autodiff(method, ABar(), Active(3.0))[1]≈ 2.0 + + @test Enzyme.autodiff_no_cassette(ABar(), Active(3.0))[1]≈ 2.0 + @test Enzyme.autodiff(ABar(), Active(3.0))[1]≈ 2.0 +end