Skip to content

Commit

Permalink
upgrade to GPUCompiler 0.8
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Nov 6, 2020
1 parent e22a306 commit eeae630
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.2.1"
version = "0.2.2"

[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
Expand All @@ -13,6 +13,6 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[compat]
Cassette = "0.3"
Enzyme_jll = "0.0.3"
GPUCompiler = "0.7"
GPUCompiler = "0.8"
LLVM = "3.2"
julia = "1.5"
15 changes: 13 additions & 2 deletions src/compiler/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,27 @@ function resolver(name, ctx)
return UInt64(reinterpret(UInt, ptr))
end

const cache = Dict{UInt, Dict{UInt, Any}}()

function thunk(f::F,tt::TT=Tuple{}) where {F<:Core.Function, TT<:Type}
primal, adjoint, rt = fspec(f, tt)

# We need to use primal as the key, to lookup the right method
# but need to mixin the hash of the adjoint to avoid cache collisions
GPUCompiler.cached_compilation(_thunk, primal, hash(adjoint), adjoint=adjoint, rt=rt)::Thunk{F,rt,tt}
# This is counter-intuitive since we would expect the cache to be split
# by the primal, but we want the generated code to be invalidated by
# invalidations of the primal, which is managed by GPUCompiler.
local_cache = get!(Dict{Int, Any}, cache, hash(adjoint))

GPUCompiler.cached_compilation(local_cache, _thunk, _link, primal, adjoint=adjoint, rt=rt)::Thunk{F,rt,tt}
end

function _link(@nospecialize(primal::FunctionSpec), thunk; kwargs...)
return thunk
end

# actual compilation
function _thunk(primal::FunctionSpec; adjoint, rt)
function _thunk(@nospecialize(primal::FunctionSpec); adjoint, rt)
target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams()
job = Compiler.CompilerJob(target, primal, params)
Expand Down

0 comments on commit eeae630

Please sign in to comment.