Skip to content

Commit

Permalink
Fix jac nout (#1864)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 20, 2024
1 parent 5a5beea commit f14bd4a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.13.1"
version = "0.13.2"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down
24 changes: 21 additions & 3 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1417,8 +1417,8 @@ end
jacobian(::ReverseMode, f, x)
Compute the jacobian of a array-output function `f` using (potentially vector)
reverse mode. The `chunk` argument denotes the chunk size to use and `n_outs`
denotes the shape of the array returned by `f`.
reverse mode. The `chunk` argument optionally denotes the chunk size to use and
`n_outs` optionally denotes the shape of the array returned by `f` (e.g `size(f(x))`).
Example:
Expand All @@ -1434,12 +1434,30 @@ jacobian(Reverse, f, [2.0, 3.0, 4.0])
```jldoctest
f(x) = [ x[1] * x[2], x[2] + x[3] ]
grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0])
# output
(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0])
```
```jldoctest
f(x) = [ x[1] * x[2], x[2] + x[3] ]
grad = jacobian(Reverse, f, [2.0, 3.0, 4.0], n_outs=Val((2,)))
# output
([3.0 2.0 0.0; 0.0 1.0 1.0],)
```
```jldoctest
f(x) = [ x[1] * x[2], x[2] + x[3] ]
grad = jacobian(ReverseWithPrimal, f, [2.0, 3.0, 4.0], n_outs=Val((2,)))
# output
(derivs = ([3.0 2.0 0.0; 0.0 1.0 1.0],), val = [6.0, 7.0])
```
This function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`.
No guarantees are presently made about the type of the AbstractArray returned by this function
(which may or may not be the same as the input AbstractArray if provided).
Expand Down Expand Up @@ -1573,7 +1591,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t
end
if ReturnPrimal
# TODO optimize away redundant fwd pass
(; derivs=res, val=if f isa Enzyme.Const
(; derivs=(res,), val=if f isa Enzyme.Const
f.val(x)
else
f(x)
Expand Down

0 comments on commit f14bd4a

Please sign in to comment.