
Make loss(f,x,y) == loss(f(x), y) #2090

Open · wants to merge 4 commits into base: master
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.13.7
* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
* All loss functions now [accept 3 arguments](https://github.com/FluxML/Flux.jl/pull/2090), `loss(model, x, y) == loss(model(x), y)`.

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
25 changes: 16 additions & 9 deletions docs/src/models/losses.md
@@ -4,21 +4,28 @@ Flux provides a large number of common loss functions used for training machine
They are grouped together in the `Flux.Losses` module.

Loss functions for supervised learning typically expect as inputs a target `y`, and a prediction `ŷ` from your model.
In Flux's convention, the order of the arguments is the following
In Flux's convention, the target is the last argument, so a new loss function could be defined:

```julia
loss(ŷ, y)
newloss(ŷ, y) = sum(abs2, ŷ .- y) # total squared error
```
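This convention can be checked with a tiny standalone sketch (no Flux required; `newloss` is just the example function above):

```julia
# Prediction first, target last — Flux's argument order for losses.
newloss(ŷ, y) = sum(abs2, ŷ .- y)  # total squared error

newloss([1.0, 2.0], [1.5, 1.5])  # (-0.5)^2 + (0.5)^2 = 0.5
```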

Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the
batch:
All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`.
This is convenient for [`train!`](@ref Flux.train!)`(loss, model, [(x,y), (x2,y2), ...], opt)`.
For our example, it could be defined as:

```julia
loss(ŷ, y) # defaults to `mean`
loss(ŷ, y, agg=sum) # use `sum` for reduction
loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction
loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean
loss(ŷ, y, agg=identity) # no aggregation.
newloss(model, x, y) = newloss(model(x), y)
```
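A minimal self-contained sketch of the two methods together, using a hypothetical `double` function as a stand-in for a trained model:

```julia
newloss(ŷ, y) = sum(abs2, ŷ .- y)            # 2-arg form: prediction and target
newloss(model, x, y) = newloss(model(x), y)  # 3-arg form: applies the model first

double(x) = 2 .* x  # hypothetical stand-in for a model

newloss(double, [1.0, 2.0], [2.0, 4.0])  # == newloss([2.0, 4.0], [2.0, 4.0]) == 0.0
```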

Most loss functions in Flux have an optional keyword argument `agg`, which is the aggregation function used over the batch.
Thus you may call, for example:

```julia
crossentropy(ŷ, y) # defaults to `Statistics.mean`
crossentropy(ŷ, y; agg = sum) # use `sum` instead
crossentropy(ŷ, y; agg = x->mean(w .* x)) # weighted mean
crossentropy(ŷ, y; agg = x->sum(x, dims=2)) # partial reduction, returns an array
```
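The `agg` pattern is easy to see in isolation. Below is a hedged sketch (a hypothetical `toyloss`, not Flux's actual source) of how such a keyword is typically wired up:

```julia
using Statistics: mean

# `agg` receives the un-reduced elementwise losses and decides how to reduce them.
toyloss(ŷ, y; agg = mean) = agg(abs2.(ŷ .- y))

toyloss([1.0, 3.0], [0.0, 1.0])                  # mean: (1 + 4) / 2 = 2.5
toyloss([1.0, 3.0], [0.0, 1.0]; agg = sum)       # 5.0
toyloss([1.0, 3.0], [0.0, 1.0]; agg = identity)  # [1.0, 4.0], no reduction
```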

### Function listing
30 changes: 27 additions & 3 deletions src/losses/Losses.jl
@@ -9,8 +9,8 @@ using CUDA
using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
import Base.Broadcast: broadcasted

export mse, mae, msle,
label_smoothing,
export label_smoothing,
mse, mae, msle,
crossentropy, logitcrossentropy,
binarycrossentropy, logitbinarycrossentropy,
kldivergence,
@@ -19,9 +19,33 @@ export mse, mae, msle,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss,
binary_focal_loss, focal_loss, siamese_contrastive_loss
binary_focal_loss, focal_loss,
siamese_contrastive_loss

include("utils.jl")
include("functions.jl")

for loss in Symbol.([
mse, mae, msle,
crossentropy, logitcrossentropy,
binarycrossentropy, logitbinarycrossentropy,
kldivergence,
huber_loss,
tversky_loss,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss,
binary_focal_loss, focal_loss,
siamese_contrastive_loss,
])
    @eval begin
        """
            $($loss)(model, x, y)

        This method calculates `ŷ = model(x)`. Accepts the same keyword arguments.
        """
        $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
    end
end

Member Author commented:

Kept this docstring short. Not so sure whether or not it will show up in the doc listing, nor whether it should.

end #module
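The `@eval` loop above generates one model-first method per loss. The same metaprogramming pattern can be sketched standalone, with hypothetical `toymse`/`toymae` functions in place of Flux's losses:

```julia
toymse(ŷ, y; agg = sum) = agg(abs2.(ŷ .- y))
toymae(ŷ, y; agg = sum) = agg(abs.(ŷ .- y))

# Generate a 3-arg, model-first method for each loss, forwarding keywords.
for loss in (:toymse, :toymae)
    @eval $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
end

toymse(x -> 2 .* x, [1.0, 2.0], [2.0, 4.0])  # 0.0
```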
9 changes: 9 additions & 0 deletions test/losses.jl
@@ -248,3 +248,12 @@ end
@test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5)
@test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1)
end

@testset "3-arg methods" begin
@testset for loss in ALL_LOSSES
fun(x) = x[1:2]
x = rand(3)
y = rand(2)
@test loss(fun, x, y) == loss(fun(x), y)
end
end