Simplify deferred functions (#1849)
* Simplify deferred functions

* fix

* Update runtests.jl

* fix

* fix

* fix

* fix

* fix
wsmoses authored Sep 18, 2024
1 parent bbaa1f8 commit 6a19be2
Showing 11 changed files with 73 additions and 87 deletions.
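
In short: `autodiff_deferred` no longer wraps the callee or guesses the return activity on the caller's behalf; call sites now pass an activity-annotated function (typically `Const(f)`) and an explicit return activity, as the updated tests and docs below show. A hedged migration sketch (`square` is illustrative and not from this diff; the call shapes follow the updated tests):

```julia
using Enzyme

square(x) = x * x

# Before this commit, the function wrapper and return activity could be left implicit:
#   autodiff_deferred(Reverse, square, Active(3.0))

# After this commit, both are stated explicitly:
autodiff_deferred(Reverse, Const(square), Active, Active(3.0))  # -> ((6.0,),)
```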
3 changes: 2 additions & 1 deletion docs/Project.toml
@@ -1,8 +1,9 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Literate = "2"
Documenter = "1"
Literate = "2"
2 changes: 1 addition & 1 deletion docs/src/faq.md
@@ -627,7 +627,7 @@ Presently Enzyme only considers floats as base types. As a result, Enzyme does n

```jldoctest types
f_int(x) = x * x
Enzyme.autodiff(Forward, f_int, DuplicatedNoNeed, Duplicated(3, 1))
Enzyme.autodiff(Forward, f_int, Duplicated, Duplicated(3, 1))
# output
29 changes: 16 additions & 13 deletions docs/src/index.md
@@ -102,7 +102,10 @@ Of note, when we seed both arguments at once the tangent return is the sum of bo

```jldoctest rosenbrock
julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0))
(400.0, -400.0)
(-400.0, 400.0)
julia> autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0))
(-400.0,)
```

We can also use forward mode with our inplace method.
@@ -118,8 +121,8 @@ julia> dx = [1.0, 1.0]
1.0
1.0
julia> autodiff(Forward, rosenbrock_inp, Duplicated, Duplicated(x, dx))
(400.0, -400.0)
julia> autodiff(ForwardWithPrimal, rosenbrock_inp, Duplicated, Duplicated(x, dx))
(-400.0, 400.0)
```

Note the seeding through `dx`.
@@ -130,7 +133,7 @@ We can also use vector mode to calculate both derivatives at once.

```jldoctest rosenbrock
julia> autodiff(ForwardWithPrimal, rosenbrock, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0)))
(400.0, (var"1" = -800.0, var"2" = 400.0))
((var"1" = -800.0, var"2" = 400.0), 400.0)
julia> x = [1.0, 3.0]
2-element Vector{Float64}:
@@ -140,7 +143,7 @@ julia> x = [1.0, 3.0]
julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0];
julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2)))
(400.0, (var"1" = -800.0, var"2" = 400.0))
((var"1" = -800.0, var"2" = 400.0), 400.0)
```

## Gradient Convenience functions
@@ -161,7 +164,7 @@ julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0])
([-400.0, 200.0],)
julia> gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0])
(derivs=[-400.0, 200.0], val=100.0)
(derivs = ([-400.0, 200.0],), val = 100.0)
julia> # inplace variant
dx = [0.0, 0.0];
@@ -177,7 +180,7 @@ julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0])
([-400.0, 200.0],)
julia> gradient(ForwardWithPrimal, rosenbrock_inp, [1.0, 2.0])
(derivs = [-400.0, 200.0], val = 100.0)
(derivs = ([-400.0, 200.0],), val = 100.0)
julia> # in forward mode, we can also optionally pass a chunk size
# to specify the number of derivatives computed simultaneously
@@ -200,22 +203,22 @@ Both forward and reverse modes take an optional chunk size to compute several de
julia> foo(x) = [rosenbrock_inp(x), prod(x)];
julia> jacobian(Reverse, foo, [1.0, 2.0])
([-400.0 200.0; 2.0 1.0],)
([-400.0 200.0; 2.0 1.0],)
julia> jacobian(ReverseWithPrimal, foo, [1.0, 2.0])
(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0])
(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0])
julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2))
([-400.0 200.0; 2.0 1.0],)
([-400.0 200.0; 2.0 1.0],)
julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2), n_outs=Val((2,)))
([-400.0 200.0; 2.0 1.0],)
([-400.0 200.0; 2.0 1.0],)
julia> jacobian(Forward, foo, [1.0, 2.0])
([-400.0 200.0; 2.0 1.0],)
([-400.0 200.0; 2.0 1.0],)
julia> jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2))
([-400.0 200.0; 2.0 1.0],)
([-400.0 200.0; 2.0 1.0],)
```

## Hessian Vector Product Convenience functions
19 changes: 11 additions & 8 deletions examples/custom_rule.jl
@@ -65,7 +65,7 @@ function forward(config::FwdConfig, func::Const{typeof(f)}, ::Type{<:Duplicated}
end

# In the signature of our rule, we have made use of `Enzyme`'s activity annotations. Let's break down each one:
# - the [`FwdConfig`](@ref) configuration passes certain compile-time information about the differentiation procedure (the width, and whether we're using runtime activity),
# - the [`EnzymeRules.FwdConfig`](@ref) configuration passes certain compile-time information about the differentiation procedure (the width, and whether we're using runtime activity),
# - the [`Const`](@ref) annotation on `f` indicates that we accept a function `f` that does not have a derivative component,
# which makes sense since `f` is not a closure with data that could be differentiated.
# - the [`Duplicated`](@ref) annotation given in the second argument annotates the return value of `f`. This means that
@@ -123,8 +123,9 @@ dy = [0.0, 0.0]
# If a custom rule is specified for the correct function/argument types, but not the correct activity annotation,
# a runtime error will be thrown alerting the user to the missing activity rule rather than silently ignoring the rule."

# Finally, it may be that either `x`, `y`, or the return value are marked as [`Const`](@ref). We can in fact handle this case,
# along with the previous two cases, all together in a single rule:
# Finally, it may be that either `x`, `y`, or the return value is marked as [`Const`](@ref), in which case we can simply return the original result. However, Enzyme may also determine that the return value is not differentiable and not needed for any other computation, in which case we should simply return `nothing`.
#
# We can in fact handle this case, along with the previous two cases, all together in a single rule by leveraging the utility functions [`EnzymeRules.needs_primal`](@ref) and [`EnzymeRules.needs_shadow`](@ref), which return `true` if the original return value or the derivative, respectively, needs to be returned:

Base.delete_method.(methods(forward, (Const{typeof(f)}, Vararg{Any}))) # delete our old rules

@@ -138,12 +139,14 @@ function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, Duplica
make_zero!(y.dval)
end
dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
if RT <: Const
if needs_primal(config) && needs_shadow(config)
return Duplicated(sum(y.val), dret)
elseif needs_primal(config)
return sum(y.val)
elseif RT <: DuplicatedNoNeed
elseif needs_shadow(config)
return dret
else
return Duplicated(sum(y.val), dret)
return nothing
end
end
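
# Because the hunk above interleaves removed and added lines, here is a hedged,
# consolidated sketch of what the post-commit rule amounts to (it assumes the
# example's earlier definitions of `f` and the `EnzymeRules` imports from the
# collapsed part of this file; the file itself is authoritative):

function forward(config, func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
                 y::Union{Const, Duplicated}, x::Union{Const, Duplicated})
    y.val .= x.val .^ 2                     # run the primal: y is always mutated
    if !(x isa Const) && !(y isa Const)
        y.dval .= 2 .* x.val .* x.dval      # propagate the tangent of x into y
    elseif !(y isa Const)
        make_zero!(y.dval)                  # x is constant, so y's tangent is zero
    end
    dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val))
    if needs_primal(config) && needs_shadow(config)
        return Duplicated(sum(y.val), dret)
    elseif needs_primal(config)
        return sum(y.val)
    elseif needs_shadow(config)
        return dret
    else
        return nothing                      # neither primal nor shadow requested
    end
end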

@@ -189,15 +192,15 @@ function augmented_primal(config::RevConfigWidth{1}, func::Const{typeof(f)}, ::T
end

# Let's unpack our signature for `augmented_primal` :
# * We accepted an [`EnzymeRules.Config`](@ref) object with a specified width of 1, which means that our rule does not support batched reverse mode.
# * We accepted an [`EnzymeRules.RevConfig`](@ref) object with a specified width of 1, which means that our rule does not support batched reverse mode.
# * We annotated `f` with [`Const`](@ref) as usual.
# * We dispatched on an [`Active`](@ref) annotation for the return value. This is a special annotation for scalar values, such as our return value,
# that indicates that we care about the value's derivative but we need not explicitly allocate a mutable shadow since it is a scalar value.
# * We annotated `x` and `y` with [`Duplicated`](@ref), similar to our first simple forward rule.

# Now, let's unpack the body of our `augmented_primal` rule:
# * We checked if the `config` requires the primal. If not, we need not compute the return value, but we make sure to mutate `y` in all cases.
# * We checked if `x` could possibly be overwritten using the `Overwritten` attribute of [`EnzymeRules.Config`](@ref).
# * We checked if `x` could possibly be overwritten using the `Overwritten` attribute of [`EnzymeRules.RevConfig`](@ref).
# If so, we save the elements of `x` on the `tape` of the returned [`EnzymeRules.AugmentedReturn`](@ref) object.
# * We return a shadow of `nothing` since the return value is [`Active`](@ref) and hence does not need a shadow.

12 changes: 12 additions & 0 deletions lib/EnzymeCore/src/rules.jl
@@ -35,7 +35,19 @@ Getters for the type parameters are provided by `needs_primal`, `needs_shadow`,
struct FwdConfig{NeedsPrimal, NeedsShadow, Width, RuntimeActivity} end
const FwdConfigWidth{Width} = FwdConfig{<:Any,<:Any,Width}

"""
needs_primal(::FwdConfig)
needs_primal(::RevConfig)
Whether a custom rule should return the original result of the function.
"""
@inline needs_primal(::FwdConfig{NeedsPrimal}) where NeedsPrimal = NeedsPrimal
"""
needs_shadow(::FwdConfig)
needs_shadow(::RevConfig)
Whether a custom rule should return the shadow (derivative) of the function result.
"""
@inline needs_shadow(::FwdConfig{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow

@inline width(::FwdConfig{<:Any, <:Any, Width}) where Width = Width
33 changes: 0 additions & 33 deletions src/Enzyme.jl
@@ -531,39 +531,6 @@ code, as well as high-order differentiation.
thunk(f, args...)
end

"""
autodiff_deferred(mode::Mode, f, ::Type{A}, args)
Like [`autodiff_deferred`](@ref) but will try to extend f to an annotation, if needed.
"""
@inline function autodiff_deferred(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs}
autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), args...)
end
@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs}
autodiff_deferred(EnzymeCore.set_err_if_func_written(mode), Const(f), RT, args...)
end

"""
autodiff_deferred(mode, f, args...)
Like [`autodiff_deferred`](@ref) but will try to guess the activity of the return value.
"""

@inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs}
tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...}
rt = if mode isa ReverseMode
Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt)
else
Core.Compiler.return_type(f.val, tt)
end

if rt === Union{}
error("return type is Union{}, giving up.")
end
rt = guess_activity(rt, mode)
autodiff_deferred(mode, f, rt, args...)
end

"""
autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs})
30 changes: 15 additions & 15 deletions test/abi.jl
@@ -20,13 +20,13 @@ using Test

@test () === autodiff(Forward, f, Const(nothing))

res = autodiff_deferred(Reverse, f, Const(nothing))
res = autodiff_deferred(Reverse, Const(f), Const, Const(nothing))
@test res === ((nothing,),)
res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing))
res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), Const(f), Const, Const(nothing))
@test res === ((nothing,),)

@test () === autodiff_deferred(Forward, f, Const(nothing))
@test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing))
@test () === autodiff_deferred(Forward, Const(f), Const, Const(nothing))
@test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), Const(f), Const, Const(nothing))

# ConstType -> Type{Int}
res = autodiff(Reverse, f, Const, Const(Int))
@@ -37,9 +37,9 @@ using Test
@test res === ((nothing,),)
@test () === autodiff(Forward, f, Const(Int))

res = autodiff_deferred(Reverse, f, Const(Int))
res = autodiff_deferred(Reverse, Const(f), Const, Const(Int))
@test res === ((nothing,),)
@test () === autodiff_deferred(Forward, f, Const(Int))
@test () === autodiff_deferred(Forward, Const(f), Const, Const(Int))

# Complex numbers
@test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im))
@@ -54,10 +54,10 @@ using Test
cres, = autodiff(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im))
@test cres 1.0 + 0.0im

@test_throws ErrorException autodiff_deferred(Reverse, f, Active(1.5 + 0.7im))
@test_throws ErrorException autodiff_deferred(ReverseHolomorphic, f, Active(1.5 + 0.7im))
@test_throws ErrorException autodiff_deferred(Reverse, Const(f), Active, Active(1.5 + 0.7im))
@test_throws ErrorException autodiff_deferred(ReverseHolomorphic, Const(f), Active, Active(1.5 + 0.7im))

cres, = autodiff_deferred(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im))
cres, = autodiff_deferred(Forward, Const(f), Duplicated, Duplicated(1.5 + 0.7im, 1.0+0im))
@test cres 1.0 + 0.0im

# Unused singleton argument
@@ -97,7 +97,7 @@ using Test

x = [0.0]
dx = [1.2]
autodiff_deferred(Reverse, squareRetArray, Const, Duplicated(x, dx))
autodiff_deferred(Reverse, Const(squareRetArray), Const, Duplicated(x, dx))

dx = [1.2]
@test () === autodiff(Forward, squareRetArray, Const, Duplicated(x, dx))
@@ -113,7 +113,7 @@ using Test
@test pair[1] 3.0
@test pair[2] 2.0

pair = autodiff_deferred(Reverse, mul, Active(2.0), Active(3.0))[1]
pair = autodiff_deferred(Reverse, Const(mul), Active, Active(2.0), Active(3.0))[1]
@test pair[1] 3.0
@test pair[2] 2.0

@@ -122,7 +122,7 @@ using Test
@test pair[2] 2.0
@test orig 6.0

pair, orig = autodiff_deferred(ReverseWithPrimal, mul, Active(2.0), Active(3.0))
pair, orig = autodiff_deferred(ReverseWithPrimal, Const(mul), Active, Active(2.0), Active(3.0))
@test pair[1] 3.0
@test pair[2] 2.0
@test orig 6.0
@@ -142,7 +142,7 @@ using Test

res = Ref(3.0)
dres = Ref(1.0)
pair, orig = autodiff_deferred(ReverseWithPrimal, inplace, Const, Duplicated(res, dres))
pair, orig = autodiff_deferred(ReverseWithPrimal, Const(inplace), Const, Duplicated(res, dres))
@test pair == (nothing,)
@test res[] 6.0
@test dres[] 2.0
@@ -163,7 +163,7 @@ using Test

res = Ref(3.0)
dres = Ref(1.0)
pair, orig = autodiff_deferred(ReverseWithPrimal, inplace2, Const, Duplicated(res, dres))
pair, orig = autodiff_deferred(ReverseWithPrimal, Const(inplace2), Const, Duplicated(res, dres))
@test pair == (nothing,)
@test res[] 6.0
@test dres[] 2.0
@@ -450,7 +450,7 @@ end
@test r[2] 100.0
@test r[1][1] -400.0
@test r[1][2] 200.0
r = autodiff_deferred(ForwardWithPrimal, rosenbrock_inp, Duplicated, BatchDuplicated(x, (dx_1, dx_2)))
r = autodiff_deferred(ForwardWithPrimal, Const(rosenbrock_inp), Duplicated, BatchDuplicated(x, (dx_1, dx_2)))
@test r[2] 100.0
@test r[1][1] -400.0
@test r[1][2] 200.0
6 changes: 3 additions & 3 deletions test/amdgpu.jl
@@ -11,7 +11,7 @@ function mul_kernel(A)
end

function grad_mul_kernel(A, dA)
autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA))
autodiff_deferred(Reverse, Const(mul_kernel), Const, Duplicated(A, dA))
return nothing
end

@@ -34,7 +34,7 @@ function exp_kernel(A)
end

function grad_exp_kernel(A, dA)
autodiff_deferred(Reverse, exp_kernel, Const, Duplicated(A, dA))
autodiff_deferred(Reverse, Const(exp_kernel), Const, Duplicated(A, dA))
return nothing
end

@@ -57,7 +57,7 @@ function cos_kernel(A)
end

function grad_cos_kernel(A, dA)
autodiff_deferred(Reverse, cos_kernel, Const, Duplicated(A, dA))
autodiff_deferred(Reverse, Const(cos_kernel), Const, Duplicated(A, dA))
return nothing
end

10 changes: 5 additions & 5 deletions test/cuda.jl
@@ -11,7 +11,7 @@ function mul_kernel(A)
end

function grad_mul_kernel(A, dA)
autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA))
autodiff_deferred(Reverse, Const(mul_kernel), Const, Duplicated(A, dA))
return nothing
end
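
# A hedged usage sketch, not part of this diff: how a test might exercise the kernel
# above under the post-commit call form, assuming (as the name suggests) that
# mul_kernel squares each element in place; the real test setup lives in the
# collapsed portion of this file and may differ.
using CUDA, Enzyme, Test
A = CUDA.ones(64)
dA = similar(A)
dA .= 1
@cuda threads=length(A) grad_mul_kernel(A, dA)
@test all(Array(dA) .≈ 2)  # d(a^2)/da at a == 1 is 2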

@@ -34,7 +34,7 @@ function exp_kernel(A)
end

function grad_exp_kernel(A, dA)
autodiff_deferred(Reverse, exp_kernel, Const, Duplicated(A, dA))
autodiff_deferred(Reverse, Const(exp_kernel), Const, Duplicated(A, dA))
return nothing
end

@@ -57,7 +57,7 @@ function cos_kernel(A)
end

function grad_cos_kernel(A, dA)
autodiff_deferred(Reverse, cos_kernel, Const, Duplicated(A, dA))
autodiff_deferred(Reverse, Const(cos_kernel), Const, Duplicated(A, dA))
return nothing
end

Expand All @@ -76,7 +76,7 @@ function val_kernel!(_, ::Val{N}) where N
end

function dval_kernel!(du, ::Val{N}) where {N}
autodiff_deferred(Reverse, val_kernel!, Const, du, Const(Val(N)))
autodiff_deferred(Reverse, Const(val_kernel!), Const, du, Const(Val(N)))
return nothing
end

@@ -123,7 +123,7 @@ function ddense!(

autodiff_deferred(
Reverse,
dense!,
Const(dense!),
Const,
dfeats_out, dfeats_in, dW, db,
Const(Val(nfeat_out)), Const(Val(nfeat_in)), Const(Val(ndof))
