From eeae6303ef2a49fe9dc0fd6a1b635a659dda675b Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 4 Nov 2020 11:55:36 -0500 Subject: [PATCH] upgrade to GPUCompiler 0.8 --- Project.toml | 4 ++-- src/compiler/thunk.jl | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index b810f06cc0..e336158046 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.2.1" +version = "0.2.2" [deps] Cassette = "7057c7e9-c182-5462-911a-8362d720325c" @@ -13,6 +13,6 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [compat] Cassette = "0.3" Enzyme_jll = "0.0.3" -GPUCompiler = "0.7" +GPUCompiler = "0.8" LLVM = "3.2" julia = "1.5" diff --git a/src/compiler/thunk.jl b/src/compiler/thunk.jl index 614524500a..bd2b870133 100644 --- a/src/compiler/thunk.jl +++ b/src/compiler/thunk.jl @@ -47,16 +47,27 @@ function resolver(name, ctx) return UInt64(reinterpret(UInt, ptr)) end +const cache = Dict{UInt, Dict{UInt, Any}}() + function thunk(f::F,tt::TT=Tuple{}) where {F<:Core.Function, TT<:Type} primal, adjoint, rt = fspec(f, tt) # We need to use primal as the key, to lookup the right method # but need to mixin the hash of the adjoint to avoid cache collisions - GPUCompiler.cached_compilation(_thunk, primal, hash(adjoint), adjoint=adjoint, rt=rt)::Thunk{F,rt,tt} + # This is counter-intuitive since we would expect the cache to be split + # by the primal, but we want the generated code to be invalidated by + # invalidations of the primal, which is managed by GPUCompiler. + local_cache = get!(Dict{Int, Any}, cache, hash(adjoint)) + + GPUCompiler.cached_compilation(local_cache, _thunk, _link, primal, adjoint=adjoint, rt=rt)::Thunk{F,rt,tt} +end + +function _link(@nospecialize(primal::FunctionSpec), thunk; kwargs...) + return thunk end # actual compilation -function _thunk(primal::FunctionSpec; adjoint, rt) +function _thunk(@nospecialize(primal::FunctionSpec); adjoint, rt) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams() job = Compiler.CompilerJob(target, primal, params)