Upgrade train! to work with explicit parameters #2029

Closed
mcabbott wants to merge 5 commits

Conversation

@mcabbott (Member) commented Jul 28, 2022

This PR proposes to move away from implicit parameters not by simply deleting train!, but by rewriting it to use explicit mode. This gives the implicit train! an easy upgrade path, and lets the new explicit train! later be switched to use something other than Zygote.

The option to use Optimisers.jl directly remains, but the style is quite different, and looking after the state yourself requires a certain amount of boilerplate. With this PR, Flux continues to offer a tidier version, which exploits mutation to update both the model and the state objects.
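
For contrast, a minimal sketch of what driving Optimisers.jl directly looks like, where the state tree has to be set up, threaded through, and stored by hand (model, data and the helper name here are placeholders, not code from this PR):

using Flux, Optimisers

# Sketch only: the caller owns the state tree and must keep re-binding it.
function train_with_optimisers(model, data)
    state = Optimisers.setup(Optimisers.Adam(), model)   # explicit state tree
    for (x, y) in data
        grad = Flux.gradient(m -> Flux.mse(m(x), y), model)[1]
        state, model = Optimisers.update(state, model, grad)  # both returns must be captured
    end
    return model, state
end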

The mutable state involves a new optimiser wrapper type, which is used for both explicit and implicit mode. Both modes use Optimisers.jl internally, so all the rule definitions in Flux.Optimise can be deleted. While many uses of the old train! will continue to work without modification, I think this is likely to be sufficiently breaking that it can only be in v0.14.
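
Roughly, the wrapper is a mutable struct of this shape (field names are guessed from the constructor calls quoted further down, not the exact definition in the diff):

# Rough sketch of the wrapper: an Optimisers.jl rule plus mutable slot for whatever
# state the chosen mode needs.
mutable struct FluxState{R}
    rule::R      # an Optimisers.jl rule, e.g. Optimisers.Adam()
    state::Any   # `missing` until first use; then a state tree (explicit) or an IdDict (implicit)
end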

Example

A simple example that runs both modes, and works if you overload explicit_withgradient to use Diffractor instead of Zygote in that mode:

using Flux, Random
data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!;

# This exact code works on Flux@0.13. There, train! returns nothing:
model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt2 = Flux.Adam()
Flux.train!(Flux.params(model2), data, opt2) do x, y
  Flux.mse(model2(x), y)
end
opt2  # contains an IdDict

# This is the new "explicit" method of Train
model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt1 = Flux.Adam()
Flux.train!(model1, data, opt1) do m, x, y
  Flux.mse(m(x), y)
end |> sum
opt1  # contains state tree

# This changes the AD used:
import Diffractor
function Flux.Train.explicit_withgradient(f, args...)
  y, back = Diffractor.∂⃖¹(f, args...)
  @info "used Diffractor!"
  return (; value = y, gradient = Base.tail(back(one(y))))
end

# This is new 3-arg train!, one step not an iteration over data:
x1, y1 = data[1]
Flux.train!(model1, opt1) do m
  Flux.mse(m(x1), y1)
end

Checklist

  • Cover all the optimisation rules
  • More tests!
  • Entry in NEWS.md
  • Many documentation changes

Comment on lines 46 to 54
for opt in [
:Descent, :Adam, :Momentum, :Nesterov, :RMSProp,
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :AdamW, :RAdam, :OAdam, :AdaBelief,
# :InvDecay, :ExpDecay, :WeightDecay, :stop, :skip, :Optimiser,
# :ClipValue, :ClipNorm,
# TODO check that parameters line up nicely old-vs-new, and include the remaining rules
]
@eval $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
end

Member:
Punning rule constructor names like this feels like a recipe for confusion. Could we compromise a bit on brevity and, say, define a shorthand alias for the FluxState constructor instead?

Member Author (mcabbott):

Might be too cute, but to try to defend this: I'm not sure it's a pun; it's more like different packages using the same word for slightly different implementations of the same concept. Like Zygote.gradient and ForwardDiff.gradient: it's no accident they share a symbol, but they are never directly interchangeable, and sometimes one will use the other internally.

There might be a much nicer name for FluxState, too.

Member:

Yeah, punning is probably not the right term. The gradient example is a good one though, because unlike Zygote and ForwardDiff, Optimisers is a dependency of Flux. It would be equivalent to Flux defining its own gradient function.

Stepping back a bit, is there a way to keep train!(loss::Function, pars::Params, opt::Flux.AbstractOptimiser) without using FluxState? I think that would allow us to be a bit more aggressive with the new train! methods.

Member Author (mcabbott):

I chose the example because ForwardDiff is also a dependency of Zygote, and will sometimes be used in the evaluation of the gradient, internally.

This PR deletes the notion of Flux.AbstractOptimiser, on the grounds that maintaining two different implementations of Adam (etc) seems like a bug magnet. And because it only takes a few lines to replace that with a Params-to-Optimisers.jl bridge. But it means the optimiser is opt::FluxState now, instead of opt::Flux.AbstractOptimiser. Because this is mutable, a lot of code which worked on 0.13 will still work with this PR. Maybe I'm not clear on what you're suggesting here.
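
(Roughly the kind of bridge meant here, as a sketch rather than the code in the diff: keep one Optimisers.jl leaf state per parameter array, in an IdDict keyed by the array. The helper name is made up.)

# Sketch of a Params-to-Optimisers.jl bridge for implicit mode (hypothetical helper):
function _implicit_update!(rule, states::IdDict, ps::Flux.Params, gs)
    for p in ps
        isnothing(gs[p]) && continue
        st = get!(() -> Optimisers.init(rule, p), states, p)  # lazily create per-array leaf state
        st, dp = Optimisers.apply!(rule, st, p, gs[p])        # rule computes the step
        states[p] = st
        p .-= dp                                              # update the array in place
    end
end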

Member:

The big question is whether the additional internal complexity and possible naming confusion are worth it if this is already a breaking change. Now, there is the question of what is a breaking change (e.g. does preventing deserialization of saved models count?), but assuming consensus on it being one I don't think the tradeoff is worth it.

One alternative would be to make a version of train! that works with Optimisers.jl types sans wrapper and ask users to qualify their imports until 0.14. This migration will require more user effort, but should also be free of any surprising errors that come from a false sense of security (this code runs, so that code should too, right?).

I think the main factors for deciding between both approaches are when 0.14 is coming and how long it will last. The longer the former, the more leeway we have to ask users to change their code. The longer the latter, the less jarring the transition from Flux.Adam() isa FluxState to Flux.Adam() isa AbstractRule will be.

@mcabbott (Member Author) commented Aug 1, 2022:

I'm not sure I follow what alternatives you are comparing.

One alternative would be to make a version of train! that works with Optimisers.jl types sans wrapper

I don't think this can work. train! needs to put the updated optimiser state somewhere. Even if it's given the Optimisers.jl state tree, not just the rule, it can't mutate that. The exception is rules like Descent() with no state, which this method accepts (and checks & warns you if necessary):
https://github.com/FluxML/Flux.jl/pull/2029/files#diff-c835714f94af5b03e96dd7e45827c090cac82c1c168f535ab0d81280de54eb69R112-R120
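
(For concreteness, a small sketch of the point, reusing model1, data, x1, y1 from the example above: Optimisers.update hands back a new state tree, which the caller has to capture; nothing mutates the tree that was passed in.)

state = Optimisers.setup(Optimisers.Adam(), model1)
grad = Flux.gradient(m -> Flux.mse(m(x1), y1), model1)[1]
new_state, new_model = Optimisers.update(state, model1, grad)
# the updated momenta live in `new_state`; `state` itself is unchanged, so train!
# would have no place to put them unless it returned them, or wrapped them in something mutable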

At present, Flux.Adam() has a mutable IdDict in which the state is stored. After this PR, Flux.Adam() is a different struct, storing things in a slightly different IdDict. But for implicit parameters, you use it the same way. No code changes are needed, except that Adam is no longer exported.

At first I was going to suggest adding just the version of train! for explicit parameters, and leaving the old one alone. But this seems more confusing, as the new one already needs some mutable container like FluxState for its own use, which is then separate both from the AbstractOptimiser container used by implicit train! and from what Optimisers.jl does: 3 different flavours. Replacing the implicit train! gets it down to 2 flavours: Flux.jl's and Optimisers.jl's.

when 0.14 is coming and how long it will last

Note that what's envisaged here is that train! with explicit parameters, and FluxState and Flux.Adam(), will last beyond 0.14. Perhaps in 0.x they will call Diffractor.gradient or Yota.grad instead. And perhaps in 0.y the implicit train!, and Zygote dep, can be dropped entirely.

@ToucheSir (Member) commented Aug 1, 2022:

new_opt_state = train!(loss, model, data, Optimisers.setup(opt)) works and also keeps things to 2 flavours.

Personally, I'd prefer not to see Flux.Adam() != Optimisers.Adam() last beyond 0.14. Perhaps Flux could absorb parts of https://github.com/FluxML/FluxTraining.jl/blob/master/src/training.jl in the future after its own functionality has been further diffused into lower-level libraries, but train! has proven to be awkward since its inception. It's illustrative to see how much "cleaner" FluxML/FluxTraining.jl#114 was than all previous attempts to modernize train!. I think that hints at the latter function being a fundamentally poor abstraction.
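
(Spelled out as a usage sketch of that hypothetical signature, reusing model1 and data from above; this is not something the PR currently implements.)

opt_state = Optimisers.setup(Optimisers.Adam(), model1)
opt_state = Flux.train!(model1, data, opt_state) do m, x, y
    Flux.mse(m(x), y)
end   # the caller re-binds the returned state tree after every call to train!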

@zsz00 (Contributor) commented Sep 17, 2022

What's the status now?

@mcabbott (Member Author) commented:

Nothing has moved really. But after reading the docs a bit, I think some variant of train! ought to survive the move away from implicit parameters.

A simpler take than this PR's present state might be to remove implicit Flux.params etc. completely, and merge something like FluxML/Optimisers.jl#106 which will, as a side-effect, let train! update the optimiser state in-place. Then there would still be a new step like opt = Flux.setup(model, Adam()) required, but the resulting opt would not need to be explicitly passed around.
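
(In usage terms, that simpler take might read something like this, with the argument order as written above; a sketch of the idea, not merged code.)

opt = Flux.setup(model, Adam())       # one-time setup step, producing a mutable state tree
Flux.train!(model, data, opt) do m, x, y
    Flux.mse(m(x), y)
end
# `opt` is updated in place, so it need not be captured or passed back explicitly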

@mcabbott (Member Author) commented:

Closing in favour of #2082

Adding more code to deal with implicit parameters in new ways doesn't seem great.

@mcabbott closed this Oct 16, 2022