[Bug] BCELoss should not be masked #1192

Closed
iamanigeeit opened this issue Feb 3, 2022 · 19 comments
Assignees
Labels
bug Something isn't working priority 🚨

Comments

@iamanigeeit

I have trained Tacotron2, but during eval / inference it often doesn't know when to stop decoding. This is a known issue in seq2seq models, and I was trying to solve it in TensorFlowTTS when I gave up due to TensorFlow problems.

Training with enable_bos_eos=True helps a bit but the output is still 3x the ground truth mel length for shorter audio: see length_data_eos.csv vs length_data_no_eos.csv

One reason is the BCELossMasked criterion -- in its current form, it encourages the model never to stop decoding once it has passed mel_length. Some of the loss results don't quite make sense, as seen below:

import torch
def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
    return mask

from torch.nn import functional
length = torch.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = torch.tensor([5.0])
target = 1. - sequence_mask(length - 1, 100).float()  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
zero_x = torch.zeros(target.shape) - 100.  # simulate logits if it never stops decoding
early_x = -200. * sequence_mask(length - 3, 100).float() + 100.  # simulate logits on early stopping
late_x = -200. * sequence_mask(length + 1, 100).float() + 100.  # simulate logits on late stopping

# if we mask
>>> functional.binary_cross_entropy_with_logits(mask * true_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(3.4657)  # Should be zero! It's not zero because of trailing zeros in the mask
>>> functional.binary_cross_entropy_with_logits(mask * zero_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)
>>> functional.binary_cross_entropy_with_logits(mask * late_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)  # Stopping late should be better than not stopping at all. Again due to trailing zeros in the mask
>>> functional.binary_cross_entropy_with_logits(mask * early_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(203.4657)  # Early stopping should be worse than late stopping because the audio will be cut
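
# sanity check on the 3.4657 offset above: the mask leaves the 5 padded frames past
# length=95 with logit 0 and target 0, and each of those costs ln(2) of loss
>>> import math
>>> 5 * math.log(2)
3.4657359027997265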

# if we don't mask
>>> functional.binary_cross_entropy_with_logits(true_x, target, pos_weight=pos_weight, reduction='sum')
tensor(0.)  # correct
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=pos_weight, reduction='sum')
tensor(3000.)  # correct
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=pos_weight, reduction='sum')
tensor(1000.)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=pos_weight, reduction='sum')
tensor(200.)  # still wrong

# pos_weight should be < 1 to penalize early stopping
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(120.0000)
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(40.0000)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(200.)  # correct

For now I am passing length=None to avoid the mask and setting pos_weight=0.2 to experiment. Will update the training results.
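
For reference, this is roughly what I mean by dropping the mask -- a minimal sketch of an unmasked stopnet criterion (class and argument names here are mine, not necessarily what should go into the repo):

import torch
from torch import nn
from torch.nn import functional

class BCELossUnmasked(nn.Module):
    """Sketch: sum BCE over every frame, padded ones included, so the model
    keeps getting penalized if it refuses to stop after the last real frame."""
    def __init__(self, pos_weight=0.2):
        super().__init__()
        self.register_buffer("pos_weight", torch.tensor([pos_weight]))

    def forward(self, logits, target, length=None):
        # `length` is deliberately ignored: no masking of padded frames
        return functional.binary_cross_entropy_with_logits(
            logits, target, pos_weight=self.pos_weight, reduction="sum"
        )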

Additional context

I would also propose renaming stop_tokens to either stop_probs or stop_logits depending on context. Currently, inference() produces stop_tokens that represent stop probabilities, while forward() produces the logits before sigmoid. Confusingly, both are called stop_tokens.
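
To make the naming point concrete (variable names below are for illustration only):

import torch
stop_logits = torch.tensor([-4.0, -1.0, 2.0, 6.0])  # what forward() returns: pre-sigmoid, any real value
stop_probs = torch.sigmoid(stop_logits)              # what inference() returns: post-sigmoid, in [0, 1]
# both are currently called stop_tokens, so it is easy to apply a threshold to logits
# or a BCE-with-logits loss to probabilities by mistake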

@iamanigeeit iamanigeeit added the bug Something isn't working label Feb 3, 2022
@erogol
Member

erogol commented Feb 6, 2022

Nice catch, and thanks for sharing the code to reproduce. It made it really easy to test out your arguments.

Masking really does seem to break the loss computation.

If you want, you can send a ✨PR✨

We could also add the cases above to the tests.

Changing stop_tokens to something else is not something I'd like to do, even though I agree with you, since it is referred to in the docs and discussions.

@iamanigeeit
Author

Nice catch, and thanks for sharing the code to reproduce. It made it really easy to test out your arguments.

Masking really does seem to break the loss computation.

If you want, you can send a ✨PR✨

We could also add the cases above to the tests.

Changing stop_tokens to something else is not something I'd like to do, even though I agree with you, since it is referred to in the docs and discussions.

I have rerun training on LJSpeech without masking, with pos_weight=0.2 and with an added stopnet_alpha=100.0 (since reducing pos_weight reduces the stopnet loss value). It doesn't reduce the output lengths and in fact makes them even longer: length_data_maskedloss.csv vs length_data_nomaskedloss.csv
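
(How I assume the weighting combines, as a sketch rather than the repo's exact loss code:)

def total_loss(decoder_loss, postnet_loss, stopnet_loss, stopnet_alpha=100.0):
    # with pos_weight=0.2 the raw stopnet term shrinks, so a larger stopnet_alpha
    # scales its contribution back up relative to the spectrogram losses
    return decoder_loss + postnet_loss + stopnet_alpha * stopnet_loss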

BUT importantly, it solved one major problem. Every sample output I have checked using the masked loss suffers from repetition (a well-known problem in NLP):

LJ025-0110
image

Using loss without masking:
image

The resulting silence can be trimmed out later, so it is OK; it just makes decoding slower.

@erogol
Member

erogol commented Feb 11, 2022

For how many steps did you train the model, and with what config?

Maybe longer training would make it better.

@iamanigeeit
Author

I used the same config as recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py except batch_size=32 (due to GPU memory limit) and r=1 (I think r=1 is the correct one for Tacotron2). Training was for 100k steps each.

@lexkoro
Collaborator

lexkoro commented Feb 11, 2022

r=1 is really hard to train with Tacotron2, especially if the dataset has samples with long silences.

Also, you might want to try activating dropout during inference; I remember it improved my inference a lot when using DCA.
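
In plain PyTorch terms that roughly means leaving the Dropout modules in training mode at inference time, something like this sketch (not the exact switch 🐸 TTS exposes):

import torch
from torch import nn

def enable_inference_dropout(model: nn.Module):
    # keep the rest of the model in eval mode, but let Dropout keep sampling masks
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()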

@erogol
Member

erogol commented Feb 14, 2022

@iamanigeeit thanks for sharing your updates.

Do you intend to send a ✨PR✨ that fixes the BCELoss?

Or we can also handle it but I guess you deserve the attribution for this. That is why it'd be nice if you could send the PR.

@erogol
Member

erogol commented Feb 14, 2022

r=1 is really hard to train with Tacotron2, especially if the dataset has samples with long silences.

Also, you might want to try activating dropout during inference; I remember it improved my inference a lot when using DCA.

Finding the right dB value for trimming the silence would help. I agree that training with r=1 is hard when the dataset is not processed well for TTS.

@iamanigeeit
Author

@iamanigeeit thanks for sharing your updates.

Do you intend to send a ✨PR✨ that fixes the BCELoss?

Or we can also handle it but I guess you deserve the attribution for this. That is why it'd be nice if you could send the PR.

Haha, I don't need the attribution... but I need help using Git, as I haven't used it on multi-person projects before. I now have a version cloned from the 33aa27 commit with all my commits pushed to my own repo, including commits on other files. How do I fix this?

Here are the files to update: updates.zip

@Edresson
Contributor

Edresson commented Feb 23, 2022

Follow these steps:

1st: Fork the 🐸 TTS repository (use the "Fork" button at the top of the page).

2nd: Clone your fork (dev branch). The command will be something like:
git clone https://github.com/iamanigeeit/TTS.git -b dev

3rd: Change the files that you need.

4th: Commit the changes with these commands (note: change the commit message :)):

git add .
git commit -m "Commit message"

5th: Push the commits to your fork with the command: git push

6th: Go to your fork (https://github.com/iamanigeeit/TTS). GitHub will detect that you've made changes and suggest a pull request, showing a pull request button below the "Go to file", "Add file" and "Code" buttons. Now you can click the pull request button and send a pull request from your dev branch to Coqui's dev branch :).

@Edresson Edresson reopened this Feb 23, 2022
@iamanigeeit
Author

@erogol I think losses should not be masked at all.

From the original Tacotron paper:

It’s a common practice to train sequence models with a loss mask, which masks loss on zero-padded frames. However, we found that models trained this way don’t know when to stop emitting outputs, causing repeated sounds towards the end. One simple trick to get around this problem is to also reconstruct the zero-padded frames.

This is what I've found as well.
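
As a rough illustration of what "also reconstruct the zero-padded frames" means (shapes and the MSE choice are assumptions for the example, not the repo's actual loss code):

import torch
from torch.nn import functional

mel_out = torch.randn(2, 100, 80)     # decoder output, padded to 100 frames, 80 mel bins
mel_target = torch.randn(2, 100, 80)  # ground truth
lengths = torch.tensor([95, 80])
mel_target[0, 95:] = 0.0              # zero padding past each item's true length
mel_target[1, 80:] = 0.0

# masked loss: padded frames contribute nothing, so the model is never told
# what to emit after the last real frame and may keep "talking"
mask = (torch.arange(100)[None, :] < lengths[:, None]).unsqueeze(-1).float()
per_element = functional.mse_loss(mel_out, mel_target, reduction="none")
masked_loss = (per_element * mask).sum() / (mask.sum() * mel_out.shape[-1])

# unmasked loss: the model must also reproduce the zero padding,
# which teaches it to go (and stay) silent once the utterance ends
unmasked_loss = functional.mse_loss(mel_out, mel_target, reduction="mean")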

@iamanigeeit
Author

@Edresson thanks for the suggestion -- I've made a lot of other changes in prior commits to my own repo, so I will probably create a new fork and update from there...

@erogol
Member

erogol commented Feb 28, 2022

I've tried no masking, but masking worked better in all of my experiments, although it was buggy.

@iamanigeeit
Author

How did you manage to solve the non-terminating outputs problem?

@erogol
Member

erogol commented Mar 1, 2022

I've not observed this problem with my models after training long enough.

@iamanigeeit
Author

Thanks, I was thinking 100k steps on batch size 32 would be enough, but it seems not... you haven't encountered inference frequently hitting max_decoder_steps?

@erogol
Member

erogol commented Mar 7, 2022

You can change max_decoder_steps in the model config.
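
For example (a sketch -- the config class and import path are assumed here, so adjust them to your version):

from TTS.tts.configs.tacotron2_config import Tacotron2Config  # assumed import path

config = Tacotron2Config()
config.max_decoder_steps = 10000  # raise the cap so long utterances are not cut off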

@iamanigeeit
Author

I set max_decoder_steps to a big number to test whether the model stops. I think it is a problem when the model doesn't stop generating, although limiting the decoder steps is a workaround.

@erogol
Member

erogol commented Mar 8, 2022

Sorry I misunderstood your question.

Thanks, I was thinking 100k steps on batch size 32 would be enough, but it seems not... you haven't encountered inference frequently hitting max_decoder_steps?

Yes, I've not observed the model hitting the max_decoder_steps after enough training.

@stale stale bot added the wontfix This will not be worked on but feel free to help. label Apr 7, 2022
@coqui-ai coqui-ai deleted a comment from stale bot Apr 11, 2022
@stale stale bot removed the wontfix This will not be worked on but feel free to help. label Apr 11, 2022
@erogol
Member

erogol commented Apr 15, 2022

@Edresson can you look at this when you have time?

@stale stale bot added the wontfix This will not be worked on but feel free to help. label May 15, 2022
@coqui-ai coqui-ai deleted a comment from stale bot May 16, 2022
@stale stale bot removed the wontfix This will not be worked on but feel free to help. label May 16, 2022
@stale stale bot added the wontfix This will not be worked on but feel free to help. label Jun 15, 2022
@coqui-ai coqui-ai deleted a comment from stale bot Jun 17, 2022
@stale stale bot removed the wontfix This will not be worked on but feel free to help. label Jun 17, 2022
erogol added a commit that referenced this issue Jul 12, 2022
@erogol erogol closed this as completed Jul 12, 2022
erogol added a commit that referenced this issue Aug 22, 2022