Commit

fix doctests
mcabbott committed Oct 12, 2022
1 parent 0d6619a commit d13e52a
Showing 2 changed files with 8 additions and 8 deletions.
10 changes: 5 additions & 5 deletions src/Optimisers.jl
@@ -77,7 +77,7 @@ or [`update!`](@ref).
 julia> m = (x = rand(3), y = (true, false), z = tanh);

 julia> Optimisers.setup(Momentum(), m)  # same field names as m
-(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing), z = nothing)
+(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
 ```

 The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -90,15 +90,15 @@ julia> struct Layer; mat; fun; end
 julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);

 julia> Optimisers.setup(Momentum(), model)  # new struct is by default ignored
-(lay = nothing, vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
+(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))

 julia> destructure(model)
 (Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))

 julia> using Functors; @functor Layer  # annotate this type as containing parameters

 julia> Optimisers.setup(Momentum(), model)
-(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = nothing), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
+(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))

 julia> destructure(model)
 (Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -120,12 +120,12 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
 julia> m = (x = Float32[1,2,3], y = tanh);

 julia> t = Optimisers.setup(Descent(0.1f0), m)
-(x = Leaf(Descent{Float32}(0.1), nothing), y = nothing)
+(x = Leaf(Descent{Float32}(0.1), nothing), y = ())

 julia> g = (x = [1,1,1], y = nothing);  # fake gradient

 julia> Optimisers.update(t, m, g)
-((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing), (x = Float32[0.9, 1.9, 2.9], y = tanh))
+((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
 ```
 """
 update
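The convention these doctests now assert — non-trainable leaves printing as `()` rather than `nothing` — can be sketched in a few lines of plain Julia. `toy_setup` below is a hypothetical stand-in for `Optimisers.setup`, not the real implementation: it attaches zeroed state to numeric arrays, recurses through tuples and named tuples, and collapses everything else to the empty tuple.

```julia
# Hypothetical toy, NOT the real Optimisers.setup: illustrates why the
# expected doctest output changed from `nothing` to `()` for non-trainable leaves.
toy_setup(x::AbstractArray{<:Number}) = (:leaf, zero(x))   # stand-in for Leaf(rule, state)
toy_setup(x::Union{Tuple,NamedTuple}) = map(toy_setup, x)  # recurse into containers
toy_setup(x) = ()                                          # non-trainable: functions, Bools, ...

m = (x = [1.0, 2.0, 3.0], y = (true, false), z = tanh)
toy_setup(m)  # (x = (:leaf, [0.0, 0.0, 0.0]), y = ((), ()), z = ())
```

Matching this against the updated doctests: the tuple `y = (true, false)` becomes `((), ())`, and a bare function such as `tanh` becomes `()`.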
6 changes: 3 additions & 3 deletions src/adjust.jl
@@ -13,15 +13,15 @@ To change just the learning rate, provide a number `η::Real`.
 julia> m = (vec = rand(Float32, 2), fun = sin);

 julia> st = Optimisers.setup(Nesterov(), m)  # stored momentum is initialised to zero
-(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = nothing)
+(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ())

 julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing));  # with fake gradient

 julia> st
-(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())

 julia> st = Optimisers.adjust(st, 0.123)  # change learning rate, stored momentum untouched
-(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
 ```

 To change other parameters, `adjust` also accepts keyword arguments matching the field
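The `adjust` doctest above changes only the learning rate while leaving the stored momentum and the `()` leaves untouched. A hypothetical toy version of that traversal (`ToyLeaf` and `toy_adjust` are illustrative names, not Optimisers' real `Leaf`/`adjust`):

```julia
# Hypothetical sketch of what adjust does to the state tree: swap the rule's
# learning rate, keep the accumulated momentum, pass () leaves through unchanged.
struct ToyLeaf
    eta::Float64
    momentum::Vector{Float64}
end
toy_adjust(leaf::ToyLeaf, eta) = ToyLeaf(eta, leaf.momentum)    # new rate, same state
toy_adjust(state::NamedTuple, eta) = map(s -> toy_adjust(s, eta), state)
toy_adjust(::Tuple{}, eta) = ()                                 # non-trainable leaf

st  = (vec = ToyLeaf(0.001, [-0.016, -0.088]), fun = ())
st2 = toy_adjust(st, 0.123)
(st2.vec.eta, st2.vec.momentum, st2.fun)  # (0.123, [-0.016, -0.088], ())
```

This mirrors the doctest: after adjusting, the rate reads `0.123`, the momentum still reads `[-0.016, -0.088]`, and `fun` stays `()`.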
