Update documentation (#681)
* Update dependencies

* Update fenced code block examples

* Fix indents

* Add headers in docstrings

* Add backticks

* Fix admonition blocks

* Fix more doctests
abhro authored Sep 3, 2024
1 parent a95c181 commit 9627bd6
Showing 18 changed files with 631 additions and 502 deletions.
851 changes: 469 additions & 382 deletions docs/Manifest.toml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/src/FAQ.md
@@ -63,7 +63,7 @@ For example, in `access(xs, n) = xs[n]`, the derivative of `access` with respect
When no custom `frule` or `rrule` exists, if you try to call one of those, it will return `nothing` by default.
As a result, you may encounter errors like

-```julia
+```plain
MethodError: no method matching iterate(::Nothing)
```

2 changes: 1 addition & 1 deletion docs/src/ad_author/opt_out.md
@@ -7,7 +7,7 @@ We provide two ways to know that a rule has been opted out of.
`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.

If you are in a position to generate code, in response to values returned by function calls then you can do something like:
-```@julia
+```julia
res = rrule(f, xs)
if res === nothing
y, pullback = perform_ad_via_decomposition(r, xs) # do AD without hitting the rrule
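The dispatch pattern in that snippet can be made concrete with a dependency-free sketch. Here `my_rrule` and `derivative_via_rule` are hypothetical stand-ins for `rrule` and the AD system's decomposition fallback, not the real ChainRules API:

```julia
# Hypothetical sketch: a missing (or opted-out) rule returns `nothing`,
# and the caller branches on that before destructuring.
my_rrule(f, x) = nothing                                # default: no rule known
my_rrule(::typeof(sin), x) = (sin(x), ȳ -> ȳ * cos(x))  # a rule we do know

function derivative_via_rule(f, x)
    res = my_rrule(f, x)
    res === nothing && return :fall_back_to_decomposition  # do AD without the rule
    y, pullback = res
    return pullback(1.0)
end
```

Calling `derivative_via_rule(sin, 0.0)` uses the rule, while any other function hits the `nothing` branch.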
74 changes: 37 additions & 37 deletions docs/src/design/changing_the_primal.md
@@ -63,7 +63,7 @@ What about using `sincos`?
```@raw html
<details open><summary>Example for `sin`</summary>
```
-```julia
+```julia-repl
julia> using BenchmarkTools
julia> @btime sin(x) setup=(x=rand());
@@ -76,7 +76,7 @@ julia> 3.838 + 4.795
8.633
```
vs computing both together:
-```julia
+```julia-repl
julia> @btime sincos(x) setup=(x=rand());
6.028 ns (0 allocations: 0 bytes)
```
@@ -96,7 +96,7 @@ So we can save time, if we can reuse that `exp(x)`.
<details open><summary>Example for the logistic sigmoid</summary>
```
If we have to compute them separately:
-```julia
+```julia-repl
julia> @btime 1/(1+exp(x)) setup=(x=rand());
5.622 ns (0 allocations: 0 bytes)
@@ -108,7 +108,7 @@ julia> 5.622 + 6.036
```

vs reusing `exp(x)`:
-```julia
+```julia-repl
julia> @btime exp(x) setup=(x=rand());
5.367 ns (0 allocations: 0 bytes)
@@ -148,8 +148,8 @@ x̄ = pullback_at(f, x, y, ȳ, intermediates)
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, (; cx=cx) # use a NamedTuple for the intermediates
+    y, cx = sincos(x)
+    return y, (; cx=cx) # use a NamedTuple for the intermediates
end

pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
@@ -163,9 +163,9 @@ pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-return y, (; ex=ex) # use a NamedTuple for the intermediates
+    ex = exp(x)
+    y = ex / (1 + ex)
+    return y, (; ex=ex) # use a NamedTuple for the intermediates
end

pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
@@ -189,8 +189,8 @@ And storing all these things on the tape — inputs, outputs, sensitivities, int
What if we generalized the idea of the `intermediate` named tuple, and had `augmented_primal` return a struct that just held anything we might want put on the tape.
```julia
struct PullbackMemory{P, S}
-primal_function::P
-state::S
+    primal_function::P
+    state::S
end
# convenience constructor:
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, state)
@@ -211,8 +211,8 @@ which is much cleaner.
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, PullbackMemory(sin; cx=cx)
+    y, cx = sincos(x)
+    return y, PullbackMemory(sin; cx=cx)
end

pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
@@ -226,9 +226,9 @@ pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-return y, PullbackMemory(σ; y=y, ex=ex)
+    ex = exp(x)
+    y = ex / (1 + ex)
+    return y, PullbackMemory(σ; y=y, ex=ex)
end

pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex)
@@ -256,8 +256,8 @@ x̄ = pb(ȳ)
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, PullbackMemory(sin; cx=cx)
+    y, cx = sincos(x)
+    return y, PullbackMemory(sin; cx=cx)
end
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx
```
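To see this version end to end, here is a runnable sketch of the same pattern. The field-forwarding `getproperty` (so `pb.cx` reaches into `state`) is plumbing the text elides — an assumption in this sketch, not shown in the diff:

```julia
struct PullbackMemory{P,S}
    primal_function::P
    state::S
end
# convenience constructor: gather keyword arguments into a NamedTuple
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, (; state...))

function Base.getproperty(pb::PullbackMemory, s::Symbol)
    s in (:primal_function, :state) && return getfield(pb, s)
    return getproperty(getfield(pb, :state), s)  # forward e.g. `pb.cx` to the state
end

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    return y, PullbackMemory(sin; cx=cx)
end
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx

y, pb = augmented_primal(sin, 1.2)  # y ≈ sin(1.2); pb(1.0) ≈ cos(1.2)
```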
@@ -271,9 +271,9 @@ end
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-return y, PullbackMemory(σ; y=y, ex=ex)
+    ex = exp(x)
+    y = ex / (1 + ex)
+    return y, PullbackMemory(σ; y=y, ex=ex)
end

(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)
@@ -295,16 +295,16 @@ Let's go back and think about the changes we would have to make to go from our orig
To rewrite that original formulation in the new pullback form we have:
```julia
function augmented_primal(::typeof(sin), x)
-y = sin(x)
-return y, PullbackMemory(sin; x=x)
+    y = sin(x)
+    return y, PullbackMemory(sin; x=x)
end
(pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x)
```
To go from that to:
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, PullbackMemory(sin; cx=cx)
+    y, cx = sincos(x)
+    return y, PullbackMemory(sin; cx=cx)
end
(pb::PullbackMemory)(ȳ) = ȳ * pb.cx
```
@@ -317,17 +317,17 @@ end
```
```julia
function augmented_primal(::typeof(σ), x)
-y = σ(x)
-return y, PullbackMemory(σ; y=y, x=x)
+    y = σ(x)
+    return y, PullbackMemory(σ; y=y, x=x)
end
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x)
```
to get to:
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex/(1 + ex)
-return y, PullbackMemory(σ; y=y, ex=ex)
+    ex = exp(x)
+    y = ex/(1 + ex)
+    return y, PullbackMemory(σ; y=y, ex=ex)
end
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex)
```
@@ -356,9 +356,9 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
-return y, pb
+    y, cx = sincos(x)
+    pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
+    return y, pb
end
```
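A quick runnable check of the closure form above — the pullback at `0.7` should agree with `cos(0.7)`:

```julia
function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    pb = ȳ -> cx * ȳ  # closes over `cx`; no explicit memory struct needed
    return y, pb
end

y, pb = augmented_primal(sin, 0.7)  # y ≈ sin(0.7); pb(1.0) ≈ cos(0.7)
```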
```@raw html
@@ -370,10 +370,10 @@ end
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
-return y, pb
+    ex = exp(x)
+    y = ex / (1 + ex)
+    pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
+    return y, pb
end
```
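The sigmoid closure can be checked against the textbook identity `σ'(x) = σ(x) * (1 - σ(x))`; note `y / (1 + ex)` equals `y * (1 - y)`. A self-contained sketch:

```julia
σ(x) = exp(x) / (1 + exp(x))  # logistic sigmoid, as in the text

function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    pb = ȳ -> ȳ * y / (1 + ex)  # closes over `y` and `ex`
    return y, pb
end

y, pb = augmented_primal(σ, 0.3)  # pb(1.0) matches y * (1 - y)
```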
```@raw html
6 changes: 3 additions & 3 deletions docs/src/design/many_tangents.md
@@ -45,7 +45,7 @@ Structural tangents are derived from the structure of the input.
Either automatically, as part of the AD, or manually, as part of a custom rule.

Consider the structure of `DateTime`:
-```julia
+```julia-repl
julia> dump(now())
DateTime
instant: UTInstant{Millisecond}
@@ -83,15 +83,15 @@ Where there is no natural tangent type for the outermost type but there is for s

Consider if we had a representation of a country's GDP as output by some continuous time model like a Gaussian Process, where that representation is as a sequence of `TimeSample`s
structured as follows:
-```julia
+```julia-repl
julia> struct TimeSample
time::DateTime
value::Float64
end
```

We can look at its structure:
-```julia
+```julia-repl
julia> dump(TimeSample(now(), 2.6e9))
TimeSample
time: DateTime
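A structural tangent for `TimeSample` would mirror that dumped field layout. ChainRulesCore would wrap this in `Tangent` types; the plain nested NamedTuple below is an assumption made for a dependency-free sketch:

```julia
using Dates  # stdlib, for DateTime

struct TimeSample
    time::DateTime
    value::Float64
end

sample = TimeSample(DateTime(2024, 9, 3), 2.6e9)
# one nested entry per field, following the `dump` output above:
tangent = (time = (instant = (periods = (value = 0,),),), value = 1.0)
```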
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -191,6 +191,7 @@ end
# output
```

+```jldoctest index
#### Find dfoo/dx via rrules
#### First the forward pass, gathering up the pullbacks
2 changes: 1 addition & 1 deletion docs/src/rule_author/converting_zygoterules.md
@@ -1,4 +1,4 @@
-# Converting ZygoteRules.@adjoint to `rrule`s
+# Converting `ZygoteRules.@adjoint` to `rrule`s

[ZygoteRules.jl](https://github.com/FluxML/ZygoteRules.jl) is a legacy package similar to ChainRulesCore but supporting [Zygote.jl](https://github.com/FluxML/Zygote.jl) only.

3 changes: 2 additions & 1 deletion docs/src/rule_author/example.md
@@ -38,8 +38,9 @@ end
```

We can check this rule against a finite-differences approach using [`ChainRulesTestUtils`](https://github.com/JuliaDiff/ChainRulesTestUtils.jl):
-```julia
+```julia-repl
julia> using ChainRulesTestUtils
julia> test_rrule(foo_mul, Foo(rand(3, 3), 3.0), rand(3, 3))
Test Summary: | Pass Total
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} | 10 10
23 changes: 13 additions & 10 deletions docs/src/rule_author/which_functions_need_rules.md
@@ -34,8 +34,9 @@ function addone(a::AbstractArray)
end
```
complains that
-```julia
+```julia-repl
julia> using Zygote
julia> gradient(addone, a)
ERROR: Mutating arrays is not supported
```
@@ -50,7 +51,7 @@ function ChainRules.rrule(::typeof(addone), a)
end
```
the gradient can be evaluated:
-```julia
+```julia-repl
julia> gradient(addone, a)
([1.0, 1.0, 1.0],)
```
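The essence of that rule can be sketched without any ChainRules dependency: compute a mutation-free primal and return an identity pullback (the real rule also returns a `NoTangent()` for the function itself, omitted here):

```julia
function addone_with_pullback(a)
    y = a .+ 1                  # allocate a new array instead of mutating
    addone_pullback(ȳ) = (ȳ,)  # the derivative of `a .+ 1` w.r.t. `a` is the identity
    return y, addone_pullback
end

y, pb = addone_with_pullback([1.0, 2.0, 3.0])  # y == [2.0, 3.0, 4.0]
```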
@@ -86,7 +87,7 @@ function exception(x)
end
```
does not work
-```julia
+```julia-repl
julia> gradient(exception, 3.0)
ERROR: Compiling Tuple{typeof(exception),Int64}: try/catch is not supported.
```
@@ -101,7 +102,7 @@ function ChainRulesCore.rrule(::typeof(exception), x)
end
```

-```julia
+```julia-repl
julia> gradient(exception, 3.0)
(6.0,)
```
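A self-contained sketch consistent with the result shown — since `gradient(exception, 3.0)` is `(6.0,)`, the primal is assumed here to behave like `x^2` guarded by `try`/`catch`, and the hand-written rule never differentiates through the `try` block:

```julia
function exception(x)
    try
        return x^2
    catch
        rethrow()  # the guard Zygote cannot compile through
    end
end

# hand-written rule sketch: derivative of x^2 is 2x
exception_rule(x) = (exception(x), ȳ -> 2x * ȳ)

y, pb = exception_rule(3.0)  # y == 9.0; pb(1.0) == 6.0
```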
@@ -123,9 +124,11 @@ function mse(y, ŷ)
end
```
takes a lot longer to AD through
-```julia
-julia> y = rand(30)
-julia> ŷ = rand(30)
+```julia-repl
+julia> y = rand(30);
+julia> ŷ = rand(30);
julia> @btime gradient(mse, $y, $ŷ)
38.180 μs (993 allocations: 65.00 KiB)
```
@@ -142,7 +145,7 @@ function ChainRules.rrule(::typeof(mse), x, x̂)
end
```
which is much faster
-```julia
+```julia-repl
julia> @btime gradient(mse, $y, $ŷ)
143.697 ns (2 allocations: 672 bytes)
```
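The hand-written gradient behind that fast rule can be sketched without ChainRules, assuming `mse(y, ŷ) = sum(abs2, y .- ŷ) / length(y)`:

```julia
mse(y, ŷ) = sum(abs2, y .- ŷ) / length(y)

function mse_with_pullback(y, ŷ)
    Ω = mse(y, ŷ)
    g = 2 .* (y .- ŷ) ./ length(y)        # shared work between both gradients
    return Ω, ȳ -> (g .* ȳ, -g .* ȳ)      # (∂/∂y, ∂/∂ŷ)
end

Ω, pb = mse_with_pullback([1.0, 2.0], [0.5, 1.0])  # Ω == 0.625
```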
@@ -159,7 +162,7 @@ function sum3(array)
return x+y+z
end
```
-```julia
+```julia-repl
julia> @btime gradient(sum3, rand(30))
424.510 ns (9 allocations: 2.06 KiB)
```
@@ -176,7 +179,7 @@ function ChainRulesCore.rrule(::typeof(sum3), a)
end
```
turns out to be significantly faster
-```julia
+```julia-repl
julia> @btime gradient(sum3, rand(30))
192.818 ns (3 allocations: 784 bytes)
```
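The essence of that rule is that the pullback of a sum is a broadcast: assuming the three partial sums in `sum3` cover every element exactly once, the whole pullback is just `fill(ȳ, length(array))`. A dependency-free sketch:

```julia
function sum3_with_pullback(array)
    Ω = sum(array)  # stand-in primal: the three slice-sums cover the array once
    return Ω, ȳ -> (fill(ȳ, length(array)),)  # every element gets the same sensitivity
end

Ω, pb = sum3_with_pullback(collect(1.0:30.0))  # Ω == 465.0
```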
6 changes: 3 additions & 3 deletions docs/src/rule_author/writing_good_rules.md
@@ -110,7 +110,7 @@ Because `typeof(Bar)` is `DataType`, using this to define an `rrule`/`frule` wil

You can check which to use with `Core.Typeof`:

-```julia
+```julia-repl
julia> function foo end
foo (generic function with 0 methods)
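The distinction can be sketched with a hypothetical `Bar`: `typeof` of a type is the too-wide `DataType`, while `Core.Typeof` gives the `Type{Bar}` needed in a rule's dispatch signature:

```julia
struct Bar end  # hypothetical type, standing in for the doc's `Bar`

wide   = typeof(Bar)       # DataType — would match every type, not just Bar
narrow = Core.Typeof(Bar)  # Type{Bar} — matches only Bar itself
```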
@@ -254,7 +254,7 @@ function ChainRulesCore.rrule(::typeof(double_it), x)
end
```
Ends up inferring a return type of `Any`
-```julia
+```julia-repl
julia> _, pullback = rrule(double_it, [2.0, 3.0])
([4.0, 6.0], var"#double_it_pullback#8"(Core.Box(var"#double_it_pullback#8"(#= circular reference @-2 =#))))
@@ -289,7 +289,7 @@ function ChainRulesCore.rrule(::typeof(double_it), x)
end
```
This infers just fine:
-```julia
+```julia-repl
julia> _, pullback = rrule(double_it, [2.0, 3.0])
([4.0, 6.0], _double_it_pullback)
