
Allow shared parameters, take III #106

Merged: 11 commits into FluxML:master, Oct 13, 2022

Conversation

@mcabbott (Member) commented Aug 28, 2022

Another take on #100. Borrows the idea of making Leaf mutable.

Tried to be simpler by pushing more of the recursion onto fmap:

  • setup is just fmapstructure really. Its notion of sharing is thus exactly the one of Functors, one source of truth. We should fix that not to share isbits types, eventually.
  • update! is just fmap. Much of the complication of the old walk was to reconstruct both the state and the model on the way out. But this isn't needed if Leaf is mutated.
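For example, a minimal sketch of the intended sharing behaviour (a toy tied-weight model; the choice of rule is arbitrary):

using Optimisers

w = rand(3, 3)
model = (layer1 = (weight = w,), layer2 = (weight = w,))  # tied array

# Functors' cache identifies the two occurrences of w, so setup
# gives them one and the same mutable Leaf in the state tree:
state = Optimisers.setup(Optimisers.Descent(0.1), model)
@assert state.layer1.weight === state.layer2.weight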

Tests from #100 pass with first commit. However, the shared Leaves must always match shared Arrays. It's possible that this scenario can be done even more simply, possibly without mutable Leaf.

What #100 does is instead to take shared Leaves as the truth about parameter sharing, which some future API could set in a way not matching the model (for ImmutableArrays, etc) even though present setup will not. Then update! cannot just be fmap, and needs one more separate IdDict for the parameters. Second commit here e84b61b bolts that on, and adds a test of it (which also pass using #100). But it's a bit ugly.

Edit: Third commit 0de29e1 instead just replaces the walk used for fmap(f, tree, x) to use re from its 2nd argument, while Functors still uses the cache on the 1st argument. That's tidier.

But the state tree contains the same () at every non-parameter node, and Functors caches the results of these... we should fix this upstream? A possible hack for now would be to supply a special cache IdDict{Leaf} which cannot store anything else -- done in e17e474.
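A hedged sketch of such a cache (the name LeafCache and all details are illustrative and may differ from e17e474; assumes Optimisers.Leaf is in scope):

struct LeafCache <: AbstractDict{Any,Any}
  dict::IdDict{Any,Any}
end
LeafCache() = LeafCache(IdDict{Any,Any}())

# Only Leaf keys are ever stored; anything else, such as the
# ubiquitous () nodes of the state tree, is silently dropped:
Base.setindex!(c::LeafCache, v, k::Leaf) = (c.dict[k] = v; c)
Base.setindex!(c::LeafCache, v, k) = c
Base.getindex(c::LeafCache, k) = c.dict[k]
Base.haskey(c::LeafCache, k) = haskey(c.dict, k)
Base.length(c::LeafCache) = length(c.dict)
Base.iterate(c::LeafCache, args...) = iterate(c.dict, args...)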

But... that's still not right. If there are mutable layer structs, then I think you cannot rely on the ID of mutable Leaf to tie things. So I gave up on customising fmap and wrote out the recursion using (x,Leaf()) as the key for reconstruction.


Gradient accumulation uses an IdDict as in #100, but stores a lazy Broadcasted object adding the pieces, which all apply! methods must therefore accept. They all do. (Edit: changed to eager addition.)
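A hedged sketch of the eager version (the helper name and dict layout are assumptions, not this PR's code):

# A Leaf shared by several arrays is encountered several times during
# the gradient walk; its pieces are summed eagerly as they arrive:
function accum_grad!(grads::IdDict{Any,Any}, leaf, dx)
  grads[leaf] = haskey(grads, leaf) ? grads[leaf] .+ dx : dx
  return grads
end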

Does not at present allow for more than one derivative, but no rules use that. (Edit: now added; it seems there were no tests of this.)

Fixes the bug noted in #100 that update could in fact mutate the state. Does this by just saying @functor Leaf. Added a test.

One further possibility with a mutable Leaf is that it can easily have a flag to mark some parameters as temporarily frozen. This is implemented here (with no API to set the flag). Not sure it's what we want though; easy to remove, but if we're changing the struct, we should consider other changes we might want.
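A standalone, hedged sketch of the flag idea (FrozenLeaf, apply_rule, and update_leaf! are illustrative names, not this PR's API):

mutable struct FrozenLeaf{R,S}
  rule::R
  state::S
  frozen::Bool   # when true, the optimiser skips this parameter
end

# Toy Descent-like rule for the sketch: no state, gradient scaled by eta.
apply_rule(eta::Real, state, x, dx) = (state, eta .* dx)

function update_leaf!(ℓ::FrozenLeaf, x, dx)
  ℓ.frozen && return x                 # temporarily frozen: no change
  ℓ.state, dx′ = apply_rule(ℓ.rule, ℓ.state, x, dx)
  x .-= dx′                            # mutate the parameter in place
  return x
end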

Because setup does not call itself in recursion, it is fairly easy to add a warning if the model has no parameters. This was something someone complained about, I forget where.
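A hedged sketch of where the warning slots in, in the package's own vocabulary (simplified, ignoring trainable; the counter mirrors the cnt = Ref(0) visible in the review excerpt below):

function setup(rule::AbstractRule, model)
  cnt = Ref(0)
  tree = fmapstructure(model, exclude = isnumeric) do x
    cnt[] += 1                      # count every trainable array seen
    Leaf(rule, init(rule, x))
  end
  cnt[] == 0 && @warn "setup found no trainable parameters in this model"
  tree
end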


Closes #42, closes #100, closes #97

@mcabbott mcabbott marked this pull request as draft August 28, 2022 16:47
@ToucheSir (Member) commented Aug 28, 2022

> A possible hack for now would be to supply a special cache IdDict{Leaf} which cannot store anything else.

I'd say this is less of a hack and something we should be doing more often. Either define a custom cache type, or (better) attach the cache to the callback itself by memoizing it. Then fmap and the rest of Functors can avoid cache management altogether.
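For instance, a hedged sketch of the memoization route (memoize is a hypothetical helper, not part of Functors):

# Wrap a callback so it carries its own IdDict, caching results only
# for mutable inputs; isbits values are recomputed, which is cheap:
function memoize(f)
  cache = IdDict()
  return x -> ismutable(x) ? get!(() -> f(x), cache, x) : f(x)
end

g = memoize(x -> x .* 2)
v = rand(3)
g(v) === g(v)   # true: the second call returns the cached array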

Review thread on src/interface.jl (outdated):

function setup(rule::AbstractRule, model)
  cnt = Ref(0)
  # Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
  tree = fmapstructure(model, exclude = isnumeric) do x
@mcabbott (Member Author):
It's pretty surprising tests pass with this, as it doesn't check trainable at all.

Review thread on update!(t′, x′, x̄s...):

function _update!(tree, x; grads, params)
  haskey(params, (tree,x)) && return params[(tree,x)]
  isbits(tree) && return x  # means () is not cached, and also (((),),)
@ToucheSir (Member), Aug 29, 2022:
This does imply we will be caching almost every level of an average Flux model (since BitsType{NotBits, BitsTypes...} is not a bitstype). objectid being not the fastest function in the world, perhaps both cache lookup and insertion should be additionally guarded by ismutable(x).

@mcabbott (Member Author):
I wondered this too. For large ImmutableArrays this may eventually need something fancier. But for now I think every fmap walk does the same thing.

@ToucheSir (Member), Aug 29, 2022:
Oh I wasn't even thinking about those, but cases like JuliaLang/julia#43542. We're unlikely to see any truly pathological behaviour, but I have to imagine the single comparison ismutable makes is more efficient than the recursive hash function objectid uses.

@mcabbott (Member Author):
OK. I guess ismutable really is right here. For parameter arrays IIRC there was a concern that it tells you e.g. that PermutedDimsArray is immutable. But for known non-leaf types, maybe it's always right?

@ToucheSir (Member), Aug 29, 2022:
Good point. PermutedDimsArray at least does implement functor, but you can always find an array wrapper which hasn't. Perhaps then the check should be isleaf instead? The isbits check is still useful either way.

Edit: I suppose isnumeric makes more sense since it forwards to isleaf already and setup guarantees only unfamiliar immutable wrappers of immutable arrays will get their own Leaf. Moving the isbits check up front also seems safe and could save a couple cycles on dict lookups.

function _update!(tree, x; grads, params)
  isbits(tree) && return x  # means () is not cached, and also (((),),)
  isnum = isnumeric(x)  
  isnum && haskey(params, (tree,x)) && return params[(tree,x)]
  children, re = functor(x)
  children′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, children)
  x′ = re(children′)
  isnum ? (params[(tree,x)] = x′) : x′
end

It's likely this can be simplified, but I wanted to get something on the page first in case there are any unforeseen edge cases present in this formulation.

@mcabbott (Member Author):
I think anything isnumeric should have a corresponding Leaf and hit the _update!(::Leaf, x; ...) method.

This one wants only to deal with mutable non-leaf things, like my mutable struct MutTwo example. Which makes me think that ismutable is fine -- if we have Foo(MutTwo(Bar(Transpose(Array)))), then the Array is the leaf, and the only level at which it's worthwhile for this method to cache anything is the MutTwo one. If this whole stack appears twice, a fresh new struct Foo cannot be distinguished from the old one.

@mcabbott (Member Author) commented Oct 11, 2022

Shall we do this?

I don't love it, and feel a bit bad about re-writing #100 in order to understand it... but this does add some features in the end.

But I do think we ought to handle shared parameters, and that we want mutable Leaf for other reasons too. (Namely: It enables freeze!. It allows for a Flux.train! without manually passing the state.)

We can re-write the internals if FluxML/Functors.jl#43 or something allows for a prettier version. The tests are pretty good.

Maybe if isbits(x) should be if !Functors.anymutable(x) from FluxML/Functors.jl#39, or a copy of that function if it's only in Functors 0.4 and we don't want to wait.

Edit: In fact perhaps setup can already be simplified by FluxML/Functors.jl#39, since fmapstructure will not create spurious ties? But it still needs a trainable walk, and maybe that's better done after FluxML/Functors.jl#43 too.

@ToucheSir (Member):

I have no objections assuming we're not considering any behavioural changes after those Functors PRs are merged.

@darsnack (Member):

I am also okay with doing this.

@mcabbott mcabbott marked this pull request as ready for review October 12, 2022 22:35
@mcabbott (Member Author):

Ok let's do it.

@mcabbott mcabbott merged commit 9c12e5d into FluxML:master Oct 13, 2022
Linked issue closed by this PR: "Optimisers.jl does not at present handle tied weights, sorry."