Use namedtuple for grad/jacobian (#1850)
* Use namedtuple for grad/jacobian

* Update index.md

* Update Enzyme.jl

* Update Enzyme.jl

* Update Enzyme.jl

* Update Enzyme.jl

* Update index.md
wsmoses authored Sep 18, 2024
1 parent 786a998 commit bbaa1f8
Showing 4 changed files with 67 additions and 55 deletions.
2 changes: 1 addition & 1 deletion docs/src/faq.md
@@ -193,7 +193,7 @@ That is why Enzyme provides a helper function `Enzyme.make_zero` that does this

```jldoctest sparse
Enzyme.make_zero(a)
Enzyme.gradient(Reverse, sum, a) # This calls make_zero(a)
Enzyme.gradient(Reverse, sum, a)[1] # This calls make_zero(a)
# output
88 changes: 50 additions & 38 deletions docs/src/index.md
@@ -76,24 +76,32 @@ Both the inplace and "normal" variant return the gradient. The difference is tha

## Forward mode

The return value of forward mode with a `Duplicated` return is a tuple containing as the first value
the primal return value and as the second value the derivative.
The return value when using `ForwardWithPrimal` is a tuple containing the derivative as the first value
and the original (primal) return value as the second.

The return value when using `Forward` is a single-element tuple containing the derivative.

In forward mode `Duplicated(x, 0.0)` is equivalent to `Const(x)`,
except that we can perform more optimizations for `Const`.

```jldoctest rosenbrock
julia> autodiff(Forward, rosenbrock, Duplicated, Const(1.0), Duplicated(3.0, 1.0))
julia> autodiff(ForwardWithPrimal, rosenbrock, Const(1.0), Duplicated(3.0, 1.0))
(400.0, 400.0)
julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Const(3.0))
(400.0, -800.0)
julia> autodiff(Forward, rosenbrock, Const(1.0), Duplicated(3.0, 1.0))
(400.0,)
julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Const(3.0))
(-800.0, 400.0)
julia> autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Const(3.0))
(-800.0,)
```
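
As a quick illustration of the `Duplicated(x, 0.0)`/`Const(x)` equivalence mentioned above, here is a small sketch (not part of the diff; the expected outputs follow from the return convention described above):

```julia
# Seeding the first argument with a zero tangent contributes nothing to the
# directional derivative, so both calls should give the same result.
autodiff(Forward, rosenbrock, Duplicated(1.0, 0.0), Duplicated(3.0, 1.0))  # expected: (400.0,)
autodiff(Forward, rosenbrock, Const(1.0), Duplicated(3.0, 1.0))            # expected: (400.0,)
```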

Of note, when we seed both arguments at once the tangent return is the sum of both.

```jldoctest rosenbrock
julia> autodiff(Forward, rosenbrock, Duplicated, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0))
julia> autodiff(ForwardWithPrimal, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0))
(400.0, -400.0)
```
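
As a sanity check (not part of the diff), the two single-argument derivatives computed earlier, -800.0 and 400.0, do sum to the tangent obtained when both arguments are seeded:

```julia
# Directional derivative along the seed (1.0, 1.0) equals d/dx + d/dy at (1.0, 3.0):
# -800.0 + 400.0 == -400.0
autodiff(Forward, rosenbrock, Duplicated(1.0, 1.0), Duplicated(3.0, 1.0))  # expected: (-400.0,)
```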

@@ -121,7 +129,7 @@ Note the seeding through `dx`.
We can also use vector mode to calculate both derivatives at once.

```jldoctest rosenbrock
julia> autodiff(Forward, rosenbrock, BatchDuplicated, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0)))
julia> autodiff(ForwardWithPrimal, rosenbrock, BatchDuplicated(1.0, (1.0, 0.0)), BatchDuplicated(3.0, (0.0, 1.0)))
(400.0, (var"1" = -800.0, var"2" = 400.0))
julia> x = [1.0, 3.0]
@@ -131,7 +139,7 @@ julia> x = [1.0, 3.0]
julia> dx_1 = [1.0, 0.0]; dx_2 = [0.0, 1.0];
julia> autodiff(Forward, rosenbrock_inp, BatchDuplicated, BatchDuplicated(x, (dx_1, dx_2)))
julia> autodiff(ForwardWithPrimal, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2)))
(400.0, (var"1" = -800.0, var"2" = 400.0))
```
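
In batch mode the derivatives come back as a named tuple with one entry per tangent. A small sketch (not part of the diff) of unpacking them, reusing `x`, `dx_1`, and `dx_2` from above; the expected values are assumptions based on the convention described earlier:

```julia
# Without the primal, autodiff returns a one-element tuple holding the batch of derivatives.
dres = autodiff(Forward, rosenbrock_inp, BatchDuplicated(x, (dx_1, dx_2)))[1]
dres[1]  # expected: -800.0, the derivative along dx_1
dres[2]  # expected:  400.0, the derivative along dx_2
```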

@@ -145,65 +153,69 @@ Like [`autodiff`](@ref), the mode (forward or reverse) is determined by the firs

The functions [`gradient`](@ref) and [`gradient!`](@ref) compute the gradient of a function with vector input and scalar return.

Gradient functions take a mode as the first argument. If the mode is `Reverse` or `Forward`, the return type is a tuple of gradients of each argument.
If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result.

```jldoctest rosenbrock
julia> gradient(Reverse, rosenbrock_inp, [1.0, 2.0])
2-element Vector{Float64}:
-400.0
200.0
([-400.0, 200.0],)
julia> gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0])
(derivs = ([-400.0, 200.0],), val = 100.0)
julia> # inplace variant
dx = [0.0, 0.0];
gradient!(Reverse, dx, rosenbrock_inp, [1.0, 2.0])
2-element Vector{Float64}:
-400.0
200.0
([-400.0, 200.0],)
julia> dx
2-element Vector{Float64}:
-400.0
200.0
julia> gradient(Forward, rosenbrock_inp, [1.0, 2.0])
(-400.0, 200.0)
([-400.0, 200.0],)
julia> gradient(ForwardWithPrimal, rosenbrock_inp, [1.0, 2.0])
(derivs = ([-400.0, 200.0],), val = 100.0)
julia> # in forward mode, we can also optionally pass a chunk size
# to specify the number of derivatives computed simultaneously
# using vector forward mode
chunk_size = Val(2)
gradient(Forward, rosenbrock_inp, [1.0, 2.0], chunk_size)
(-400.0, 200.0)
gradient(Forward, rosenbrock_inp, [1.0, 2.0]; chunk=Val(1))
([-400.0, 200.0],)
```
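
Because the `WithPrimal` modes return a named tuple, the result can be destructured by field name. A minimal sketch (not part of the diff), assuming the values from the example above:

```julia
res = gradient(ReverseWithPrimal, rosenbrock_inp, [1.0, 2.0])
res.derivs  # expected: ([-400.0, 200.0],), one gradient per differentiable argument
res.val     # expected: 100.0, the primal value rosenbrock_inp([1.0, 2.0])
```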

## Jacobian Convenience functions

The function [`jacobian`](@ref) computes the Jacobian of a function with vector input and vector return.
Like [`autodiff`](@ref) and [`gradient`](@ref), the mode (forward or reverse) is determined by the first argument.

Again like [`gradient`](@ref), if the mode is `Reverse` or `Forward`, the return type is a tuple of jacobians of each argument.
If the mode is `ReverseWithPrimal` or `ForwardWithPrimal`, the return type is a named tuple containing both the derivatives and the original return result.

Both forward and reverse modes take an optional chunk size to compute several derivatives simultaneously using vector mode, and reverse mode optionally takes `n_outs` which describes the shape of the output value.

```jldoctest rosenbrock
julia> foo(x) = [rosenbrock_inp(x), prod(x)];
julia> output_size = Val(2) # here we have to provide the output size of `foo` since it cannot be statically inferred
jacobian(Reverse, foo, [1.0, 2.0], output_size)
2×2 transpose(::Matrix{Float64}) with eltype Float64:
-400.0 200.0
2.0 1.0
julia> jacobian(Reverse, foo, [1.0, 2.0])
([-400.0 200.0; 2.0 1.0],)
julia> chunk_size = Val(2) # By specifying the optional chunk size argument, we can use vector reverse mode to propagate derivatives of multiple outputs at once.
jacobian(Reverse, foo, [1.0, 2.0], output_size, chunk_size)
2×2 transpose(::Matrix{Float64}) with eltype Float64:
-400.0 200.0
2.0 1.0
julia> jacobian(ReverseWithPrimal, foo, [1.0, 2.0])
(derivs = ([-400.0 200.0; 2.0 1.0],), val = [100.0, 2.0])
julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2))
([-400.0 200.0; 2.0 1.0],)
julia> jacobian(Reverse, foo, [1.0, 2.0]; chunk=Val(2), n_outs=Val((2,)))
([-400.0 200.0; 2.0 1.0],)
julia> jacobian(Forward, foo, [1.0, 2.0])
2×2 Matrix{Float64}:
-400.0 200.0
2.0 1.0
julia> # Again, the optional chunk size argument allows us to use vector forward mode
jacobian(Forward, foo, [1.0, 2.0], chunk_size)
2×2 Matrix{Float64}:
-400.0 200.0
2.0 1.0
([-400.0 200.0; 2.0 1.0],)
julia> jacobian(Forward, foo, [1.0, 2.0], chunk=Val(2))
([-400.0 200.0; 2.0 1.0],)
```
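
The Jacobian variants follow the same convention; a short sketch (not part of the diff) of pulling the pieces out of the `ReverseWithPrimal` result, assuming the values from the example above:

```julia
res = jacobian(ReverseWithPrimal, foo, [1.0, 2.0])
J = res.derivs[1]  # expected: the 2×2 Jacobian [-400.0 200.0; 2.0 1.0]
y = res.val        # expected: [100.0, 2.0], the primal output foo([1.0, 2.0])
```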

## Hessian Vector Product Convenience functions
@@ -257,4 +269,4 @@ julia> grad
2-element Vector{Float64}:
2.880510859951098
1.920340573300732
```
```
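
The collapsed hunk above covers the Hessian-vector product helpers; a minimal sketch of the basic pattern, not taken from this diff and assuming the three-argument `hvp(f, x, v)` form, with outputs rounded:

```julia
g(x) = sin(x[1] * x[2])

# Hessian-vector product of g at [2.0, 3.0] in the direction [5.0, 2.7]
hvp(g, [2.0, 3.0], [5.0, 2.7])  # roughly [19.69, 16.20]
```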
28 changes: 14 additions & 14 deletions src/Enzyme.jl
@@ -1082,21 +1082,21 @@ a tuple where the first element contains the derivatives, and the second element
grad = gradient(ReverseWithPrimal, f, [2.0, 3.0])
# output
(([3.0, 2.0],), 6.0)
(derivs = ([3.0, 2.0],), val = 6.0)
```
```jldoctest gradient
grad = gradient(ReverseWithPrimal, mul, [2.0], [3.0])
# output
(([3.0], [2.0]), 6.0)
(derivs = ([3.0], [2.0]), val = 6.0)
```
```jldoctest gradient
grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0]))
# output
(([3.0], nothing), 6.0)
(derivs = ([3.0], nothing), val = 6.0)
```
"""
@@ -1161,7 +1161,7 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0]))
return quote
Base.@_inline_meta
$(toemit...)
(($(resargs...),), res[2])
(; derivs=($(resargs...),), val=res[2])
end
else
return quote
@@ -1196,14 +1196,14 @@ dx = [0.0, 0.0]
gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0])
# output
(([3.0, 2.0],), 6.0)
(derivs = ([3.0, 2.0],), val = 6.0)
```
"""
@inline function gradient!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten}
make_zero!(dx)
res = autodiff(rm, f, Active, Duplicated(x, dx))
return if ReturnPrimal
((dx,), res[2])
(; derivs=(dx,), val=res[2])
else
(dx,)
end
@@ -1300,7 +1300,7 @@ gradient(Forward, f, [2.0, 3.0])
gradient(ForwardWithPrimal, f, [2.0, 3.0])
# output
(([3.0, 2.0],), 6.0)
(derivs = ([3.0, 2.0],), val = 6.0)
```
```jldoctest gradfwd
@@ -1315,7 +1315,7 @@ gradient(Forward, f, [2.0, 3.0]; chunk=Val(1))
gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1))
# output
(([3.0, 2.0],), 6.0)
(derivs = ([3.0, 2.0],), val = 6.0)
```
For functions which return an AbstractArray or scalar, this function will return an AbstractArray
@@ -1336,10 +1336,10 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
"""
@inline function gradient(fm::ForwardMode{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity}, f, x; chunk::CS=nothing, shadows=create_shadows(chunk, x)) where {ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity, CS}
if length(shadows[1]) == 0
if ReturnPrimal
((x,), f(x.val))
return if ReturnPrimal
(; derivs=(x,), val=f(x.val))
else
return (x,)
(x,)
end
end
if chunk == Val(0)
@@ -1430,7 +1430,7 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
cols
end
if ReturnPrimal
((res,), gradtup[2])
(; derivs=(res,), val=gradtup[2])
else
(res,)
end
@@ -1498,7 +1498,7 @@ this function will return an AbstractArray of shape `size(output)` of values of t
end

return if ReturnPrimal
(jac, res)
(; derivs=jac, val=res)
else
jac
end
@@ -1606,7 +1606,7 @@ this function will return an AbstractArray of shape `size(output)` of values of t
end
if ReturnPrimal
# TODO optimize away redundant fwd pass
(res, if f isa Enzyme.Const
(; derivs=res, val=if f isa Enzyme.Const
f.val(x)
else
f(x)
4 changes: 2 additions & 2 deletions test/ext/bfloat16s.jl
@@ -2,6 +2,6 @@ using Enzyme
using Test
using BFloat16s

@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10)
@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10)

@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10)
@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10))[1] ≈ ones(BFloat16, 10)
