
Make loss(f,x,y) == loss(f(x), y) #2090

Open · mcabbott wants to merge 4 commits into base: master
Conversation

mcabbott
Member

@mcabbott mcabbott commented Oct 18, 2022

If train! stops accepting implicit parameters, as in #2082, then its loss function needs to accept the model as an argument, rather than close over it.

This PR makes all the built-in loss functions do so, to avoid having to define `loss(m, x, y) = mse(m(x), y)` etc. yourself every time.

(Defining `loss(x, y) = mse(model(x), y)` every time used to be the idiom for closing over the model, and IMO this is pretty confusing: it makes "loss function" mean two different things. It seems cleaner to delete that idiom entirely than to update it to a 3-arg version.)
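For concreteness, a minimal sketch of the idea, written as a standalone helper with a hypothetical name (`mse3`) rather than the actual method the PR adds to `Flux.Losses.mse`:

```julia
using Flux

# Hypothetical standalone version of the 3-argument method this PR adds to
# each built-in loss (the PR defines it on Flux.Losses.mse etc. directly).
mse3(model, x, y; kws...) = Flux.Losses.mse(model(x), y; kws...)

model = Dense(3 => 1)
x, y = rand(Float32, 3, 10), rand(Float32, 1, 10)

mse3(model, x, y) == Flux.Losses.mse(model(x), y)   # true
```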

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

Comment on lines 13 to 18
All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
This is convenient for [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`:

```julia
loss(ŷ, y) # defaults to `mean`
loss(ŷ, y, agg=sum) # use `sum` for reduction
loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction
loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean
loss(ŷ, y, agg=identity) # no aggregation.
loss(model, x, y) = loss(model(x), y)
```
Member

@darsnack darsnack Oct 20, 2022


GH won't let me suggest on this easily, but right now, it almost reads like you need to define the 3-arg loss to work with train! (which is the exact opposite intent!). Something like

All loss functions in Flux have a method which takes the model as the first argument, calculates the prediction `ŷ = model(x)`, and finally the loss `loss(ŷ, y)`. This is convenient for passing the loss function directly to [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`. For a custom loss, you can replicate this as:
```julia
myloss(model, x, y) = myloss(model(x), y)
```

Member Author


Yes, I wondered this too. In this doc section, "loss" stands for any built-in loss function.

I wonder if it should use, say, `mse` everywhere, and say "Flux has a method like this already defined:"?

Member

@darsnack darsnack Oct 20, 2022


Yeah, maybe it is clearer to start this section by saying something like "Using Flux.Losses.mse as an example, ...". Then say, for this specific point,

All loss functions in Flux have a method which takes the model as the first argument, and calculates the loss such that
```julia
Flux.Losses.mse(model, x, y) == Flux.Losses.mse(model(x), y)
```
This is convenient for passing the loss function directly to [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`.

Member Author


Turns out I was halfway through changing this section locally to work through defining a new loss function, rather than listing properties of the existing ones. See what you think? Agreed that if it does discuss the existing ones, it should use `==`.

@darsnack
Member

A NEWS entry for this feature would be good too

"""
$($loss)(model, x, y)

This method calculates `ŷ = model(x)`. Accepts the same keyword arguments.
Member Author


Kept this docstring short. Not so sure whether or not it will show up in the doc listing, nor whether it should.

@ToucheSir
Member

ToucheSir commented Oct 20, 2022

Sorry, I have to say that I'm really not a fan of this signature because it excludes a bunch of models while adding one more thing to know for loss function authors. For example, what does mse(m, x, y) even mean if you're doing self-supervised learning and m is some siamese network?

Given that the existing train! API requires users to define their own zero-arg "loss" function already, could we not keep that constraint (bring your own function) and pass in (m, x, y)? This would be strictly less confusing than the status quo and we could rename the callback to "forward pass" or some such.
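(For concreteness, a hypothetical sketch of the kind of self-supervised setup the siamese-network point has in mind, where a generic `loss(m, x, y) = loss(m(x), y)` method has no obvious meaning:)

```julia
using Flux

# Hypothetical siamese-style setup: both inputs go through the same encoder
# and there is no single target `y`, so the user ends up writing their own
# 3-argument "loss" regardless of any built-in method.
encoder = Chain(Dense(10 => 8, relu), Dense(8 => 4))

siamese_loss(m, xa, xb) = Flux.Losses.mse(m(xa), m(xb))

xa, xb = rand(Float32, 10, 16), rand(Float32, 10, 16)
siamese_loss(encoder, xa, xb)
```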

@mcabbott
Member Author

mcabbott commented Oct 20, 2022

Yes I agree it's specialised to some uses.

It just seems slightly weird to force people to define a function which is just adjusting the signature to fit, not doing any work or making any choices. They are forced to do so now because, in addition, this function closes over the model, so it must be re-defined if you change the model.

I suppose it seems especially odd if the "official" documented way is that you must name this trivial function. And perhaps always writing something like this would be less odd:

train!(model, [(x1,y1), (x2,y2), ...], opt) do m,x,y
  mse(m(x), y)
end

However, that's still quite a bit of boilerplate just to say "use mse". And I know some people find the do syntax super-confusing at first.

@ToucheSir
Member

If it were just a matter of clarifying how the do syntax works, we could address this with a docs issue. But to the brevity point, ideally we'd be able to extract out some loss(f,x,y) = loss(f(x), y) helper so that individual loss functions don't have to be responsible for being model-aware? It would be one more verb/noun to learn, but it would save us confused users who ported over a loss(x, y) function from some other library and don't understand the resulting MethodError (I'm assuming that if they don't understand do, they'd have a hard time with this too). If this wrapper were a named type, there's even a chance to toss in optimization state and thus simplify #2082, but I haven't thought too hard about that yet.
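(One possible shape for such a wrapper, sketched here as a hypothetical named type that is not part of Flux:)

```julia
# Hypothetical wrapper type: turns any 2-argument `loss(ŷ, y)` into the
# `loss(model, x, y)` form that an explicit-parameter train! would call.
struct ApplyLoss{L}
    loss::L
end

(a::ApplyLoss)(model, x, y; kws...) = a.loss(model(x), y; kws...)

# e.g. train!(ApplyLoss(Flux.Losses.mse), model, data, opt)
```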

@mcabbott
Member Author

> confused users who ported over a loss(x, y) function from some other library and don't understand the resulting MethodError

Right now this is worse: `loss(x, y) = norm(x - y)` will result in zero gradients but no error.

For implicit-Flux, having methods like `mse(m) = (x, y) -> mse(m(x), y)` would allow `train!(mse(model), params(model), data, opt)`, which is less obscure than what we have now. Or it could be spelled `train!(applyloss(mse, model), params(model), data, opt)`, with one more verb.

For explicit-Flux, we could have `train!(applyloss(mse), model, data, opt)`. Not a big fan of a verb exclusively to translate the built-in losses into what the built-in train! wants, though.

We could also just make `train!(loss, model, [(x1,y1), (x2,y2)], opt)` call `loss(model(x1), y1)`. The rule would then be `loss(model(data[1][1]), data[1][2:end]...)` instead of `loss(model, data[1]...)` from #2082? Not sure about that either.
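(For concreteness, the two `applyloss` spellings mentioned above could look like this; the name is hypothetical and not part of Flux:)

```julia
# Implicit style: returns a 2-argument loss which closes over the model.
applyloss(loss, model) = (x, y) -> loss(model(x), y)

# Explicit style: returns a 3-argument loss, with the model passed on each call.
applyloss(loss) = (m, x, y) -> loss(m(x), y)
```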

@ToucheSir
Member

ToucheSir commented Oct 22, 2022

> Right now this is worse: `loss(x, y) = norm(x - y)` will result in zero gradients but no error.

Yeah, that's a good argument for having the model(x) part more explicit. Is it too bad to ask users to write a 2-liner?

data = [(x1,y1), (x2,y2), ...]
train!((m, x, y) -> mse(m(x), y), model, data, opt)

Most users can directly copy-paste this, and those who have more complex forward passes can either define a separate function or ease into learning the do syntax. And if one just wants to add a regularization term, it's still a one-liner:

train!((m, x, y) -> mse(m(x), y) + Optimisers.total(norm, m), model, data, opt)
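(A self-contained sketch of this copy-paste pattern, assuming the explicit train!/setup API proposed in #2082:)

```julia
using Flux, Optimisers

model = Dense(3 => 1)
data  = [(rand(Float32, 3, 16), rand(Float32, 1, 16)) for _ in 1:4]
opt   = Flux.setup(Optimisers.Adam(0.01), model)

# Anonymous 3-argument loss: nothing closes over `model`.
Flux.train!((m, x, y) -> Flux.Losses.mse(m(x), y), model, data, opt)
```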

@mcabbott
Member Author

OK, https://fluxml.ai/Flux.jl/previews/PR2114/training/training/ takes the view that we should just always make an anonymous function. It emphasises gradient + update over train!, and for gradient you are always going to want that anyway. And it explains the do block several times.
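(And a sketch of that gradient + update style, under the same explicit-API assumptions as the previous snippet:)

```julia
using Flux, Optimisers

model = Dense(3 => 1)
data  = [(rand(Float32, 3, 16), rand(Float32, 1, 16)) for _ in 1:4]
opt_state = Flux.setup(Optimisers.Adam(0.01), model)

for (x, y) in data
    # Gradient with respect to the model itself, no implicit params().
    grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end
```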
