Add explicit train!, unify update!, and auto-translate the two `Adam`s (#2082)

* explicit train, take 2

* remove train_autodiff macro

* make it stricter, to avoid batchmemaybe weirdness

* remove 3-argument train!, since it requires an impure loss function and you can just use update! instead

* remove issingletontype purity check, too strict

* tidy up

* fix tests

* use _old_to_new in Optimisers.setup too

* oops

* return nothing

* test NaN + error, tidy up

* fix test

* remove 2 vs 3 argument comment from docstring

* nice errors for update! with mixed-up input

* fix doctest by making "using Statistics" explicit

* also delete Flux.params() from the example completely

* fix

* fix
mcabbott authored Nov 20, 2022
1 parent ad988da commit 8d948e8
Showing 10 changed files with 365 additions and 32 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.13.7
* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
* New method of `train!` using Zygote's "explicit" mode. Part of a move away from "implicit" `Params`.

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
2 changes: 1 addition & 1 deletion Project.toml
@@ -34,7 +34,7 @@ MacroTools = "0.5"
NNlib = "0.8.9"
NNlibCUDA = "0.2.4"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2.1"
Optimisers = "0.2.10"
ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
4 changes: 2 additions & 2 deletions docs/make.jl
@@ -1,10 +1,10 @@
using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore
using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Statistics


DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

makedocs(
modules = [Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Base],
modules = [Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Base, Statistics],
doctest = false,
sitename = "Flux",
# strict = [:cross_references,],
42 changes: 16 additions & 26 deletions docs/src/models/overview.md
@@ -77,13 +77,15 @@ julia> predict(x_train)
In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> loss(x, y) = Flux.Losses.mse(predict(x), y);
julia> using Statistics
julia> loss(x_train, y_train)
julia> loss(model, x, y) = mean(abs2.(model(x) .- y));
julia> loss(predict, x_train, y_train)
122.64734f0
```

More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/). Flux works by iteratively reducing the loss through *training*.
More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/) (and built-in as [`mse`](@ref Flux.Losses.mse)). Flux works by iteratively reducing the loss through *training*.
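For reference, a small sketch (not part of this commit's diff) checking that the hand-written loss agrees with Flux's built-in `mse`; the model and data here are made up:

```julia
using Flux, Statistics

model = Dense(1 => 1)                                 # made-up model
x, y  = rand(Float32, 1, 5), rand(Float32, 1, 5)      # made-up data

loss(m, x, y) = mean(abs2.(m(x) .- y))                # the loss from the updated docs
loss(model, x, y) ≈ Flux.Losses.mse(model(x), y)      # expected to be true
```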

## 3. Improve the Prediction

@@ -112,40 +114,28 @@ julia> predict.bias
0.0
```

The dimensions of these model parameters depend on the number of inputs and outputs. Since models can have hundreds of inputs and several layers, it helps to have a function to collect the parameters into the data structure Flux expects:

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> parameters = Flux.params(predict)
Params([Float32[0.9066542], Float32[0.0]])
```

These are the parameters Flux will change, one step at a time, to improve predictions. At each step, the contents of this `Params` object changes too, since it is just a collection of references to the mutable arrays inside the model:

```jldoctest overview
julia> predict.weight in parameters, predict.bias in parameters
(true, true)
```
The dimensions of these model parameters depend on the number of inputs and outputs.

The first parameter is the weight and the second is the bias. Flux will adjust predictions by iteratively changing these parameters according to the optimizer.
Flux will adjust predictions by iteratively changing these parameters according to the optimizer.

This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this:

```jldoctest overview
julia> train!(loss, parameters, data, opt)
julia> train!(loss, predict, data, opt)
```

And check the loss:

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
116.38745f0
```

It went down. Why?

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> parameters
Params([Float32[7.5777884], Float32[1.9466728]])
julia> predict.weight, predict.bias
(Float32[7.5777884], Float32[1.9466728])
```

The parameters have changed. This single step is the essence of machine learning.
@@ -156,14 +146,14 @@ In the previous section, we made a single call to `train!` which iterates over t

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> for epoch in 1:200
train!(loss, parameters, data, opt)
train!(loss, predict, data, opt)
end
julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
0.00339581f0
julia> parameters
Params([Float32[4.0178537], Float32[2.0050256]])
julia> predict.weight, predict.bias
(Float32[4.0178537], Float32[2.0050256])
```

After 200 training steps, the loss went down, and the parameters are getting close to those in the function the model is built to predict.
Expand All @@ -188,7 +178,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t

Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet.

After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.
After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.

After we trained the model, we checked it against the test data to verify the results.
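Putting the revised overview together, here is a minimal sketch of the new explicit style end to end; the data, model, and hyperparameters are illustrative rather than copied from the tutorial:

```julia
using Flux, Statistics

x_train = hcat(Float32.(0:5)...)          # 1×6 matrix of inputs (illustrative)
y_train = 4f0 .* x_train .+ 2f0           # targets from y = 4x + 2

model = Dense(1 => 1)
loss(m, x, y) = mean(abs2.(m(x) .- y))

data      = [(x_train, y_train)]
opt_state = Flux.setup(Descent(), model)  # setup translates the old-style rule

for epoch in 1:200
    Flux.train!(loss, model, data, opt_state)
end

loss(model, x_train, y_train)             # should now be close to zero
```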

4 changes: 4 additions & 0 deletions src/Flux.jl
@@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp,
AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
WeightDecay, ClipValue, ClipNorm

include("train.jl")
using .Train
# using .Train: setup, @train_autodiff

using CUDA
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

102 changes: 102 additions & 0 deletions src/deprecations.jl
@@ -82,3 +82,105 @@ Base.@deprecate_binding ADAGrad AdaGrad
Base.@deprecate_binding ADADelta AdaDelta

@deprecate rng_from_array() default_rng_value()

#=
# Valid method in Optimise, old implicit style, is:
train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
# Valid methods in Train, new explict style, are:
train!(loss, model, data, opt) # preferred
train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup
# Provide friendly errors for what happens if you mix these up:
=#
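For orientation, a sketch (not part of the diff) of the preferred explicit call; mixing `Flux.params` into it instead would land on the error methods defined just below:

```julia
using Flux, Optimisers

model = Dense(2 => 1)
data  = [(rand(Float32, 2, 8), rand(Float32, 1, 8))]
loss(m, x, y) = Flux.Losses.mse(m(x), y)

opt_state = Flux.setup(Optimisers.Adam(), model)   # new explicit state
Flux.train!(loss, model, data, opt_state)          # preferred method

# Flux.train!(loss, Flux.params(model), data, opt_state)  # would hit the error below
```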
import .Optimise: train!

train!(loss, ps::Params, data, opt) = error(
"""can't mix implict Params with explict state!
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
But better to use the new explicit style, in which `m` itself is the 2nd argument.
""")

train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error(
"""can't mix implict Params with explict rule from Optimisers.jl
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
But better to use the new explicit style, in which `m` itself is the 2nd argument.
""")

train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))

# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
import .Train: setup
setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
# ... and allow accidental use of `Optimisers.setup` to do the same:
Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
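As a hedged illustration (model and learning rate are made up), both calls below should then build equivalent optimiser state, the second passing through `_old_to_new`:

```julia
using Flux, Optimisers

model = Dense(3 => 2)                                  # made-up model
state_new = Flux.setup(Optimisers.Adam(0.01), model)   # new-style rule
state_old = Flux.setup(Flux.Adam(0.01), model)         # old-style rule, translated here
```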

for T in [:Descent, :Adam, :Momentum, :Nesterov,
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
# :InvDecay, :ExpDecay,
]
@eval function _old_to_new(rule::$T)
args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
Optimisers.$T(args...)
end
end
_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
const ClipGrad = Optimise.ClipValue
_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred

_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")
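A sketch of the chain translation (model and rules here are illustrative): an old-style composed `Optimiser` should come out of `setup` as an `Optimisers.OptimiserChain`:

```julia
using Flux

model     = Dense(1 => 1)
old_rule  = Flux.Optimiser(ClipValue(1f0), Adam(1e-3))   # old-style composition
opt_state = Flux.setup(old_rule, model)                  # goes through _old_to_new,
                                                         # i.e. OptimiserChain(ClipGrad(1f0), Adam(0.001))
```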

# Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot.
# But let's make sure that such uses give a helpful error:
import .Optimise: update!

function update!(opt::Optimise.AbstractOptimiser, model, grad)
# This error method requires narrowing the main worker method of Flux.Optimise
# to accept only arrays. Remove if this causes problems!
# update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄)
error("""Invalid input to `update!`.
* For the implicit style, this needs `update!(::AbstractOptimiser, ::Params, ::Grads)`
* For the explicit style, `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
""")
end

# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1]
# Can't catch every case, but can catch many simple Flux models:

function update!(opt, model::Chain, grads::Tuple)
# Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent
@warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone,
not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`."""
update!(opt, model, grads[1])
end

function update!(opt::Optimise.AbstractOptimiser, model::Chain, grads::Tuple) # ambiguity
update!(opt, model, grads[1]) # calls error case "Invalid input" just above
end
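An illustrative sketch of the mistake these methods catch, and the intended call; the model and data are made up:

```julia
using Flux, Optimisers

model = Chain(Dense(2 => 1))
x, y  = rand(Float32, 2, 4), rand(Float32, 1, 4)
opt_state = Flux.setup(Optimisers.Adam(), model)

grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)  # a 1-tuple
Flux.update!(opt_state, model, grads[1])    # intended call: unpack the tuple
# Flux.update!(opt_state, model, grads)     # would trigger the warning above
```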

# One more easy error to catch is using explicit gradient with `params(m)`:

function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, NamedTuple})
error("""can't mix implicit Params with explicit gradients!
* For the implicit style, this needs `update!(::AbstractOptimiser, ::Params, ::Grads)` with implicit gradient.
* For the explicit style, `update!(state, model, grad)` needs the model itself, and `state = Flux.setup(opt, model)`.
""")
end

# v0.14 deprecations

# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
# Base.@deprecate_binding Optimiser OptimiserChain
# Base.@deprecate_binding ClipValue ClipGrad

# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
# where `loss_mxy` accepts the model as its first argument.
# """
# ))
15 changes: 13 additions & 2 deletions src/optimise/train.jl
@@ -1,18 +1,25 @@
using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient, withgradient

# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
# for both explicit and implicit parameters.
import Optimisers.update!

"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.
!!! note
This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14.
The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain.
"""
function update!(opt::AbstractOptimiser, x, x̄)
function update!(opt::AbstractOptimiser, x::AbstractArray, x̄)
x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not
# safe due to aliasing, nor guaranteed to be possible, e.g. Fill.
x .-= apply!(opt, x, x̄r)
@@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x
Uses a `loss` function and training `data` to improve the
model's parameters according to a particular optimisation rule `opt`.
!!! note
This method with implicit `Params` will be removed from Flux 0.14.
It should be replaced with the explicit method `train!(loss, model, data, opt)`.
For each `d in data`, first the gradient of the `loss` is computed like this:
```
gradient(() -> loss(d...), pars) # if d isa Tuple
