Skip to content

Commit

Permalink
Implement EnzymeRules
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Feb 3, 2023
1 parent 5fce502 commit e199976
Show file tree
Hide file tree
Showing 13 changed files with 901 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
CEnum = "0.4"
EnzymeCore = "0.1"
EnzymeCore = "0.2"
Enzyme_jll = "0.0.48"
GPUCompiler = "0.16.7, 0.17"
LLVM = "4.14"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.1.0"
version = "0.2.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 2 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,6 @@ struct ForwardMode <: Mode
end
const Forward = ForwardMode()

include("rules.jl")

end # module EnzymeCore
117 changes: 117 additions & 0 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
module EnzymeRules

import EnzymeCore: Annotation
export Config, ConfigWidth
export needs_primal, needs_shadow, width, overwritten

import Base: unwrapva, isvarargtype, unwrap_unionall, rewrap_unionall

"""
forward(func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...)
Calculate the forward derivative. The first argument `func` is the callable
for which the rule applies to. Either wrapped in a [`Const`](@ref)), or
a [`Duplicated`](@ref) if it is a closure.
The second argument is the return type annotation, and all other arguments are
the annotated function arguments.
"""
function forward end

struct Config{NeedsPrimal, NeedsShadow, Width, Overwritten} end
const ConfigWidth{Width} = Config{<:Any,<:Any, Width}

needs_primal(::Config{NeedsPrimal}) where NeedsPrimal = NeedsPrimal
needs_shadow(::Config{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow
width(::Config{<:Any, <:Any, Width}) where Width = Width
overwritten(::Config{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten

"""
augmented_primal(::Config, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...)
Must return a tuple of length 2.
The first-value is primal value and the second is the tape. If no tape is
required return `(val, nothing)`.
"""
function augmented_primal end

"""
reverse(::Config, func::Annotation{typeof(f)}, dret::Annotation, tape, args::Annotation...)
Takes gradient of derivative, activity annotation, and tape
"""
function reverse end

function _annotate(@nospecialize(T))
if isvarargtype(T)
VA = T
T = _annotate(VA.T)
if isdefined(VA, :N)
return Vararg{T, VA.N}
else
return Vararg{T}
end
else
return TypeVar(gensym(), Annotation{T})
end
end

function _annotate_tt(@nospecialize(TT0))
TT = Base.unwrap_unionall(TT0)
ft = TT.parameters[1]
tt = map(T->_annotate(Base.rewrap_unionall(T, TT0)), TT.parameters[2:end])
return ft, tt
end

function has_frule_from_sig(@nospecialize(TT); world=Base.get_world_counter())
ft, tt = _annotate_tt(TT)
TT = Tuple{<:Annotation{ft}, Type{<:Annotation}, tt...}
isapplicable(forward, TT; world)
end

function has_rrule_from_sig(@nospecialize(TT); world=Base.get_world_counter())
ft, tt = _annotate_tt(TT)
TT = Tuple{<:Config, <:Annotation{ft}, <:Annotation, <:Any, tt...}
isapplicable(reverse, TT; world)
end

function has_frule(@nospecialize(f); world=Base.get_world_counter())
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, Vararg{<:Annotation}}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, TT.parameters...}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(RT::Type); world=Base.get_world_counter())
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, Vararg{<:Annotation}}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(RT::Type), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, TT.parameters...}
isapplicable(forward, TT; world)
end

# Base.hasmethod is a precise match we want the broader query.
function isapplicable(@nospecialize(f), @nospecialize(TT); world=Base.get_world_counter())
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
return !isempty(Base._methods_by_ftype(sig, -1, world)) # TODO cheaper way of querying?
end

function issupported()
@static if VERSION < v"1.7.0"
return false
else
return true
end
end

end # EnzymeRules
3 changes: 2 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export markType, batch_size, onehot, chunkedonehot
using LinearAlgebra
import EnzymeCore: ReverseMode, ForwardMode, Annotation, Mode

import EnzymeCore: EnzymeRules

# Independent code, must be loaded before "compiler.jl"
include("pmap.jl")

Expand Down Expand Up @@ -79,7 +81,6 @@ end
end
end


include("logic.jl")
include("typeanalysis.jl")
include("typetree.jl")
Expand Down
5 changes: 5 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHan

EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T)

EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils)
EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils)
EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val)
EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig)
Expand All @@ -216,7 +217,11 @@ EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocati
EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils)

EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val)

EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size)

EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign)

EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}), gutils, orig, needsPrimalP, needsShadowP)

EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme),
Expand Down
Loading

0 comments on commit e199976

Please sign in to comment.