Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user-defined rules #177

Closed
wants to merge 16 commits into from
2 changes: 2 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,6 @@ struct ForwardMode <: Mode
end
const Forward = ForwardMode()

include("rules.jl")

end # module EnzymeCore
113 changes: 113 additions & 0 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module EnzymeRules

import EnzymeCore: Annotation
export Config, ConfigWidth
export needs_primal, needs_shadow, width, overwritten

"""
forward(func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...)

Calculate the forward derivative. The first argument `func` is the callable
for which the rule applies to. Either wrapped in a [`Const`](@ref)), or
a [`Duplicated`](@ref) if it is a closure.
The second argument is the return type annotation, and all other arguments are
the annotated function arguments.
"""
function forward end

struct Config{NeedsPrimal, NeedsShadow, Width, Overwritten} end
const ConfigWidth{Width} = Config{<:Any,<:Any, Width}

needs_primal(::Config{NeedsPrimal}) where NeedsPrimal = NeedsPrimal
needs_shadow(::Config{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow
width(::Config{<:Any, <:Any, Width}) where Width = Width
overwritten(::Config{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten

"""
augmented_primal(::Config, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...)

Must return a tuple of length 2.
The first-value is primal value and the second is the tape. If no tape is
required return `(val, nothing)`.
"""
function augmented_primal end

"""
reverse(::Config, func::Annotation{typeof(f)}, dret::Annotation, tape, args::Annotation...)

Takes gradient of derivative, activity annotation, and tape
"""
function reverse end

_annotate(T::DataType) = TypeVar(gensym(), Annotation{T})
_annotate(::Type{T}) where T = TypeVar(gensym(), Annotation{T})
function _annotate(VA::Core.TypeofVararg)
T = _annotate(VA.T)
if isdefined(VA, :N)
return Vararg{T, VA.N}
else
return Vararg{T}
end
end

function has_frule_from_sig(@nospecialize(TT); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
ft = TT.parameters[1]
tt = map(_annotate, TT.parameters[2:end])
TT = Tuple{<:Annotation{ft}, Type{<:Annotation}, tt...}
isapplicable(forward, TT; world)
end

function has_rrule_from_sig(@nospecialize(TT); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
ft = TT.parameters[1]
tt = map(_annotate, TT.parameters[2:end])
TT = Tuple{<:Config, <:Annotation{ft}, <:Annotation, <:Any, tt...}
isapplicable(reverse, TT; world)
end

function has_frule(@nospecialize(f); world=Base.get_world_counter())
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, Vararg{<:Annotation}}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, TT.parameters...}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(RT::Type); world=Base.get_world_counter())
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, Vararg{<:Annotation}}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(RT::Type), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, TT.parameters...}
isapplicable(forward, TT; world)
end

# Base.hasmethod is a precise match we want the broader query.
function isapplicable(@nospecialize(f), @nospecialize(TT); world=Base.get_world_counter())
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
return !isempty(Base._methods_by_ftype(sig, -1, world)) # TODO cheaper way of querying?
end

function has_rrule(@nospecialize(TT), world=Base.get_world_counter())
return false
end

function issupported()
@static if VERSION < v"1.7.0"
return false
else
return true
end
end

end # EnzymeRules
3 changes: 2 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export markType, batch_size, onehot, chunkedonehot
using LinearAlgebra
import EnzymeCore: ReverseMode, ForwardMode, Annotation, Mode

import EnzymeCore: EnzymeRules

# Independent code, must be loaded before "compiler.jl"
include("pmap.jl")

Expand Down Expand Up @@ -61,7 +63,6 @@ end
end
end


include("logic.jl")
include("typeanalysis.jl")
include("typetree.jl")
Expand Down
5 changes: 5 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHan

EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T)

EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils)
EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils)
EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val)
EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig)
Expand All @@ -211,7 +212,11 @@ EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocati
EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils)

EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val)

EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size)

EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign)

EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}), gutils, orig, needsPrimalP, needsShadowP)

EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme),
Expand Down
Loading