Make loss(f,x,y) == loss(f(x), y) #2090

Open · wants to merge 4 commits into base: master
Changes from 3 commits
21 changes: 13 additions & 8 deletions docs/src/models/losses.md
@@ -4,21 +4,26 @@ Flux provides a large number of common loss functions used for training machine
They are grouped together in the `Flux.Losses` module.

Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ` from your model.
In Flux's convention, the order of the arguments is the following
In Flux's convention, the target is the last argument:

```julia
loss(ŷ, y)
```
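
For instance, a quick sketch of this convention using the built-in `mse` (the numbers here are made up for illustration):

```julia
using Flux.Losses: mse

ŷ = [1.0, 2.0]   # prediction from a model
y = [1.0, 1.0]   # target
mse(ŷ, y)        # == 0.5, the mean of (ŷ .- y).^2
```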

Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
batch:
All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
This is convenient for [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`:

```julia
loss(ŷ, y) # defaults to `mean`
loss(ŷ, y, agg=sum) # use `sum` for reduction
loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction
loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean
loss(ŷ, y, agg=identity) # no aggregation.
loss(model, x, y) = loss(model(x), y)
```
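
As a concrete sketch of why this is convenient (assuming the explicit-style `train!` API this method targets; the model and data here are invented for illustration):

```julia
using Flux

model = Dense(3 => 2)
data = [(rand(Float32, 3), rand(Float32, 2)) for _ in 1:4]
opt = Flux.setup(Adam(), model)

# train! calls mse(model, x, y) for each (x, y) in data,
# which the 3-arg method turns into mse(model(x), y)
Flux.train!(Flux.Losses.mse, model, data, opt)
```
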
@darsnack (Member) commented on Oct 20, 2022:

GH won't let me suggest on this easily, but right now, it almost reads like you need to define the 3-arg loss to work with train! (which is the exact opposite intent!). Something like

All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`, and finally the loss `loss(ŷ, y)`. This is convenient for passing the loss function directly to [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`. For a custom loss, you can replicate this as:
```julia
myloss(model, x, y) = myloss(model(x), y)
```

Member Author replied:

Yes I wondered this too. In this doc section "loss" is an example of any built-in one.

I wonder if it should use, say, `mse` everywhere, and say "Flux has a method like this already defined:"?

@darsnack (Member) replied on Oct 20, 2022:

Yeah, maybe it is clearer to start this section by saying something like "Using Flux.Losses.mse as an example, ...". Then say, for this specific point,

All loss functions in Flux have a method which takes the model as the first argument, and calculates the loss such that
```julia
Flux.Losses.mse(model, x, y) == Flux.Losses.mse(model(x), y)
```
This is convenient for passing the loss function directly to [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`.

Member Author replied:

Turns out I was half-done with changing this section locally to work through defining a new loss, rather than listing properties of existing ones. See what you think? Agreed that if it does discuss existing ones, it should be `==`.


Most loss functions in Flux have an optional keyword argument `agg`, which is the aggregation function used over the batch:

```julia
loss(ŷ, y) # defaults to `Statistics.mean`
loss(ŷ, y; agg = sum) # use `sum` instead
loss(ŷ, y; agg = x->mean(w .* x)) # weighted mean
loss(ŷ, y; agg = x->sum(x, dims=2)) # partial reduction, returns an array
```
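
To make the aggregation concrete, here is a small sketch using `mse` (values chosen by hand):

```julia
using Statistics: mean
using Flux.Losses: mse

ŷ, y = [1.0, 2.0], [0.0, 0.0]   # squared errors are [1.0, 4.0]

mse(ŷ, y)             # 2.5 == mean([1.0, 4.0])
mse(ŷ, y; agg = sum)  # 5.0 == sum([1.0, 4.0])
```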

### Function listing
30 changes: 27 additions & 3 deletions src/losses/Losses.jl
@@ -9,8 +9,8 @@ using CUDA
using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
import Base.Broadcast: broadcasted

export mse, mae, msle,
label_smoothing,
export label_smoothing,
mse, mae, msle,
crossentropy, logitcrossentropy,
binarycrossentropy, logitbinarycrossentropy,
kldivergence,
@@ -19,9 +19,33 @@ export mse, mae, msle,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss,
binary_focal_loss, focal_loss, siamese_contrastive_loss
binary_focal_loss, focal_loss,
siamese_contrastive_loss

include("utils.jl")
include("functions.jl")

for loss in Symbol.([
mse, mae, msle,
crossentropy, logitcrossentropy,
binarycrossentropy, logitbinarycrossentropy,
kldivergence,
huber_loss,
tversky_loss,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss,
binary_focal_loss, focal_loss,
siamese_contrastive_loss,
])
@eval begin
"""
$($loss)(model, x, y)

This method calculates `ŷ = model(x)`. Accepts the same keyword arguments.
Member Author commented:

Kept this docstring short. Not so sure whether or not it will show up in the doc listing, nor whether it should.

"""
$loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
end
end

end #module
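
For reference, each pass through the `@eval` loop above generates a definition equivalent to the following, shown here expanded by hand for `mse`:

```julia
"""
    mse(model, x, y)

This method calculates `ŷ = model(x)`. Accepts the same keyword arguments.
"""
mse(f, x::AbstractArray, y::AbstractArray; kw...) = mse(f(x), y; kw...)
```
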
9 changes: 9 additions & 0 deletions test/losses.jl
@@ -248,3 +248,12 @@ end
@test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5)
@test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1)
end

@testset "3-arg methods" begin
@testset for loss in ALL_LOSSES
fun(x) = x[1:2]
x = rand(3)
y = rand(2)
@test loss(fun, x, y) == loss(fun(x), y)
end
end
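
Outside the test suite, the property being tested can be sanity-checked at the REPL like this (a sketch; any built-in loss should behave the same way):

```julia
julia> using Flux

julia> m = Dense(4 => 2);

julia> x, y = rand(Float32, 4, 8), rand(Float32, 2, 8);

julia> Flux.Losses.mse(m, x, y) == Flux.Losses.mse(m(x), y)
true
```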