Skip to content

Commit

Permalink
Fix BCELoss adressing #1192
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jul 12, 2022
1 parent eefd482 commit a6f73a1
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
4 changes: 2 additions & 2 deletions TTS/tts/configs/tacotron_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig):
enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
stopnet_pos_weight (float):
Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
datasets with longer sentences. Defaults to 10.
datasets with longer sentences. Defaults to 0.2.
max_decoder_steps (int):
Max number of steps allowed for the decoder. Defaults to 50.
encoder_in_features (int):
Expand Down Expand Up @@ -161,7 +161,7 @@ class TacotronConfig(BaseTTSConfig):
prenet_dropout_at_inference: bool = False
stopnet: bool = True
separate_stopnet: bool = True
stopnet_pos_weight: float = 10.0
stopnet_pos_weight: float = 0.2
max_decoder_steps: int = 500
encoder_in_features: int = 256
decoder_in_features: int = 256
Expand Down
16 changes: 6 additions & 10 deletions TTS/tts/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,16 @@ def forward(self, align):
"""
Forces attention to be more decisive by penalizing
soft attention weights
TODO: arguments
TODO: unit_test
"""
entropy = torch.distributions.Categorical(probs=align).entropy()
loss = (entropy / np.log(align.shape[1])).mean()
return loss


class BCELossMasked(nn.Module):
def __init__(self, pos_weight):
def __init__(self, pos_weight:float=None):
super().__init__()
self.pos_weight = pos_weight
self.pos_weight = torch.tensor([pos_weight])

def forward(self, x, target, length):
"""
Expand All @@ -179,16 +176,15 @@ class for each corresponding step.
Returns:
loss: An average loss value in range [0, 1] masked by the length.
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
if length is not None:
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
x = x * mask
target = target * mask
# mask: (batch, max_len, 1)
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
num_items = mask.sum()
loss = functional.binary_cross_entropy_with_logits(x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum")
else:
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
num_items = torch.numel(x)
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
loss = loss / num_items
return loss

Expand Down
39 changes: 38 additions & 1 deletion tests/tts_tests/test_losses.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import unittest
import torch as T
from torch.nn import functional

from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked
from TTS.tts.layers.losses import L1LossMasked, SSIMLoss, MSELossMasked, BCELossMasked


class L1LossMaskedTests(unittest.TestCase):
Expand Down Expand Up @@ -200,3 +201,39 @@ def test_in_out(self): # pylint: disable=no-self-use
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.item() == 0, "0 vs {}".format(output.item())


class BCELossTest(unittest.TestCase):
def test_in_out(self): # pylint: disable=no-self-use
layer = BCELossMasked(pos_weight=5.0)

length = T.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = T.tensor([5.0])
target = 1. - sequence_mask(length - 1, 100).float() # [0, 0, .... 1, 1] where the first 1 is the last mel frame
true_x = target * 200 - 100 # creates logits of [-100, -100, ... 100, 100] corresponding to target
zero_x = T.zeros(target.shape) - 100. # simulate logits if it never stops decoding
early_x = -200. * sequence_mask(length - 3, 100).float() + 100. # simulate logits on early stopping
late_x = -200. * sequence_mask(length + 1, 100).float() + 100. # simulate logits on late stopping

loss = layer(true_x, target, length)
self.assertEqual(loss.item(), 0.0)

loss = layer(early_x, target, length)
self.assertAlmostEqual(loss.item(), 2.1053, places=4)

loss = layer(late_x, target, length)
self.assertAlmostEqual(loss.item(), 5.2632, places=4)

loss = layer(zero_x, target, length)
self.assertAlmostEqual(loss.item(), 5.2632, places=4)

# pos_weight should be < 1 to penalize early stopping
layer = BCELossMasked(pos_weight=0.2)
loss = layer(true_x, target, length)
self.assertEqual(loss.item(), 0.0)

# when pos_weight < 1 overweight the early stopping loss
loss_early = layer(early_x, target, length)
loss_late = layer(late_x, target, length)
self.assertGreater(loss_early.item(), loss_late.item())

0 comments on commit a6f73a1

Please sign in to comment.