Unintuitive reduction of mini-batch loss for NLLLoss #9882
Labels
module: docs
Related to our documentation, both in docs/ and docblocks
module: loss
Problem is related to loss function
module: nn
Related to torch.nn
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
I find the reduction method that was chosen for NLLLoss quite unintuitive.
This introduces a weird interdependence between the chosen class weights and the chosen batch size (and more: the influence of the class weights depends on which ground-truth classes are present in the mini-batch), since, as far as I can tell, the reduced loss is divided by the sum of the weights of the targets in the batch rather than by the batch size.
Extreme case with the current implementation: with a batch size of one, it does not matter which class weights I choose; my net will always see the same gradients (see the sketch below).
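A minimal sketch of that extreme case, assuming a recent PyTorch where the old `reduce=True` is spelled `reduction='mean'`; the class count and weight values are made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 3, requires_grad=True)   # batch size 1, 3 (made-up) classes
target = torch.tensor([1])

# With the default reduction, the weighted sum of the per-sample losses is
# divided by the sum of the weights of the targets in the batch, so with a
# single sample the class weight cancels out completely.
for w in (torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 10.0, 1.0])):
    logits.grad = None
    loss = F.nll_loss(F.log_softmax(logits, dim=1), target, weight=w)
    loss.backward()
    print(loss.item(), logits.grad)
# both iterations print the same loss and the same gradient
```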
In other words: I would expect
F.nll_loss(..., reduce=True) == torch.mean(F.nll_loss(..., reduce=False))
but this does not hold true when using different class weights.
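A minimal sketch of the mismatch, again assuming a recent PyTorch where `reduce=True`/`reduce=False` are spelled `reduction='mean'`/`reduction='none'` (the tensors and weights are made up purely for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(4, 3), dim=1)   # batch of 4, 3 classes
target = torch.tensor([0, 1, 1, 2])
weight = torch.tensor([1.0, 5.0, 1.0])

reduced = F.nll_loss(log_probs, target, weight=weight)                       # reduce=True
per_sample = F.nll_loss(log_probs, target, weight=weight, reduction='none')  # reduce=False

print(reduced)                                   # sum(w_i * l_i) / sum(w_i)
print(per_sample.mean())                         # sum(w_i * l_i) / N  -- what I expected
print(per_sample.sum() / weight[target].sum())   # reproduces the reduced value
```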
The documentation of CrossEntropyLoss also says that "The losses are averaged across observations for each minibatch." That sentence in particular is very misleading with the current implementation if you are using class weights.
I can only guess that this implementation was chosen so that the loss value doesn't change when you change the class weights (which makes multiple runs with different class weights more comparable when you are only looking at the loss). But it comes at the cost of a very unintuitive treatment of class weights that, in my opinion, is not worth it.
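One way to read that guess, sketched with made-up tensors: under the current reduction, rescaling all class weights by a constant factor leaves the loss value unchanged, whereas a plain mean over the weighted per-sample losses would scale with the weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(4, 3), dim=1)
target = torch.tensor([0, 1, 1, 2])
base = torch.tensor([1.0, 5.0, 1.0])

# Current reduction: numerator and denominator both scale by 10, so the value is stable.
print(F.nll_loss(log_probs, target, weight=base))
print(F.nll_loss(log_probs, target, weight=10 * base))

# A plain mean of the weighted per-sample losses would scale by the same factor.
print(F.nll_loss(log_probs, target, weight=base, reduction='none').mean())
print(F.nll_loss(log_probs, target, weight=10 * base, reduction='none').mean())
```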
cc @jlin27 @mruberry @albanD @jbschlosser