Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Add Mixed Representation Training #3473

Merged
merged 35 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
25e7dff
Update CMUdict with ADLR version pronunciations
redoctopus Jan 14, 2022
bf21276
minor updates for finetuning
blisc Jan 18, 2022
6bb0f0b
update conf
blisc Jan 19, 2022
8db6e5c
merge
blisc Jan 19, 2022
ed6121f
Merge remote-tracking branch 'nvidia/cmudict_update' into tts_finetun…
blisc Jan 19, 2022
725e15c
update
blisc Jan 19, 2022
d0f7edc
update
blisc Jan 19, 2022
eacd2ba
bug fixes
blisc Jan 19, 2022
3793c13
update config
blisc Jan 20, 2022
729fc4c
bf16 support
blisc Jan 20, 2022
539a4ab
bf16 support
blisc Jan 20, 2022
05309f1
bugfix
blisc Jan 21, 2022
b479e7a
update
blisc Jan 24, 2022
3869f79
finalize changes
blisc Jan 26, 2022
cdc1f6c
merge with main
blisc Feb 1, 2022
23aeb2c
undo notebook 1.6.0 pins
blisc Feb 1, 2022
0d33f71
more 1.6.0 undos
blisc Feb 1, 2022
780828c
wip
blisc Feb 3, 2022
b54f390
update num_workers
blisc Feb 3, 2022
b09dee9
update hypers
blisc Feb 3, 2022
a4124e3
revert to main _align yamls
blisc Feb 4, 2022
7898a40
update yamls
blisc Feb 4, 2022
1342c23
cleanup
blisc Feb 4, 2022
832f431
merge with main
blisc Feb 4, 2022
d8f796e
remove unnecessary line
blisc Feb 7, 2022
a31e824
address comments
blisc Feb 7, 2022
4f09e27
merge
blisc Feb 7, 2022
341deb2
update vocoder mel uploading; add contextmanager to mixed g2p
blisc Feb 10, 2022
1201136
update comments; make prob required argument
blisc Feb 11, 2022
bc218ce
added val check
blisc Feb 11, 2022
0f9386e
update message
blisc Feb 11, 2022
8d0baaa
update
blisc Feb 11, 2022
2a34a70
Merge remote-tracking branch 'nvidia/main' into tts_finetuning_updates
blisc Feb 11, 2022
28d1fd5
revert num workers
blisc Feb 11, 2022
3e3fc21
Merge branch 'main' into tts_finetuning_updates
blisc Feb 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -2017,7 +2017,7 @@ pipeline {
}
}


stage('L2: Megatron GPT Convert from Megatron-LM checkpoing and Eval') {
when {
anyOf {
Expand Down
7 changes: 5 additions & 2 deletions examples/tts/conf/fastpitch_align_v1.05.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ model:
_target_: nemo.collections.tts.torch.g2ps.EnglishG2p
phoneme_dict: ${phoneme_dict_path}
heteronyms: ${heteronyms_path}
phoneme_probability: 0.5

train_ds:
dataset:
Expand All @@ -101,12 +102,13 @@ model:
pitch_norm: true
pitch_mean: ${model.pitch_mean}
pitch_std: ${model.pitch_std}
use_beta_binomial_interpolator: true

dataloader_params:
drop_last: false
shuffle: true
batch_size: 32
num_workers: 12
num_workers: 0
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved

validation_ds:
dataset:
Expand All @@ -131,12 +133,13 @@ model:
pitch_norm: true
pitch_mean: ${model.pitch_mean}
pitch_std: ${model.pitch_std}
use_beta_binomial_interpolator: true

dataloader_params:
drop_last: false
shuffle: false
batch_size: 32
num_workers: 8
num_workers: 0

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
Expand Down
12 changes: 6 additions & 6 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,20 +310,20 @@ def training_step(self, batch, batch_idx):

self.tb_logger.add_image(
"train_mel_target",
plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
self.global_step,
dataformats="HWC",
)
spec_predict = mels_pred[0].data.cpu().numpy()
spec_predict = mels_pred[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"train_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
)
if self.learn_alignment:
attn = attn_hard[0].data.cpu().numpy().squeeze()
attn = attn_hard[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_attn", plot_alignment_to_numpy(attn.T), self.global_step, dataformats="HWC",
)
soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_soft_attn", plot_alignment_to_numpy(soft_attn.T), self.global_step, dataformats="HWC",
)
Expand Down Expand Up @@ -394,11 +394,11 @@ def validation_epoch_end(self, outputs):
if isinstance(self.logger, TensorBoardLogger):
self.tb_logger.add_image(
"val_mel_target",
plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()),
self.global_step,
dataformats="HWC",
)
spec_predict = spec_predict[0].data.cpu().numpy()
spec_predict = spec_predict[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"val_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
)
Expand Down
32 changes: 24 additions & 8 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,12 @@ def __init__(

# Initialize text tokenizer
self.text_tokenizer = text_tokenizer

self.phoneme_probability = None
if isinstance(self.text_tokenizer, BaseTokenizer):
self.text_tokenizer_pad_id = text_tokenizer.pad
self.tokens = text_tokenizer.tokens
self.phoneme_probability = self.text_tokenizer.phoneme_probability
else:
if text_tokenizer_pad_id is None:
raise ValueError(f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer")
Expand All @@ -146,6 +149,7 @@ def __init__(

self.text_tokenizer_pad_id = text_tokenizer_pad_id
self.tokens = tokens
self.cache_text = True if self.phoneme_probability is None else False

# Initialize text normalizer is specified
self.text_normalizer = text_normalizer
Expand Down Expand Up @@ -179,15 +183,14 @@ def __init__(

if "normalized_text" not in item:
text = item["text"]

if self.text_normalizer is not None:
text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)

file_info["normalized_text"] = text
file_info["text_tokens"] = self.text_tokenizer(text)
else:
file_info["normalized_text"] = item["normalized_text"]
file_info["text_tokens"] = self.text_tokenizer(item["normalized_text"])

if self.cache_text:
file_info["text_tokens"] = self.text_tokenizer(file_info["normalized_text"])

data.append(file_info)

Expand Down Expand Up @@ -241,6 +244,7 @@ def __init__(
hop_length=self.hop_len,
win_length=self.win_length,
window=window_fn(self.win_length, periodic=False).to(torch.float) if window_fn else None,
return_complex=True,
)

# Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
Expand Down Expand Up @@ -331,6 +335,13 @@ def add_align_prior_matrix(self, **kwargs):
self.align_prior_matrix_folder.mkdir(exist_ok=True, parents=True)

self.use_beta_binomial_interpolator = kwargs.pop('use_beta_binomial_interpolator', False)
if not self.cache_text:
if 'use_beta_binomial_interpolator' in kwargs and not self.use_beta_binomial_interpolator:
logging(
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
"phoneme_probability is not None, but use_beta_binomial_interpolator=False, we"
" set use_beta_binomial_interpolator=True manually to use phoneme_probability."
)
self.use_beta_binomial_interpolator = True
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved

if self.use_beta_binomial_interpolator:
self.beta_binomial_interpolator = BetaBinomialInterpolator()
Expand Down Expand Up @@ -386,9 +397,13 @@ def __getitem__(self, index):
features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
audio, audio_length = features, torch.tensor(features.shape[0]).long()

# Load text
text = torch.tensor(sample["text_tokens"]).long()
text_length = torch.tensor(len(sample["text_tokens"])).long()
if "text_tokens" in sample:
text = torch.tensor(sample["text_tokens"]).long()
text_length = torch.tensor(len(sample["text_tokens"])).long()
else:
tokenized = self.text_tokenizer(sample["normalized_text"])
text = torch.tensor(tokenized).long()
text_length = torch.tensor(len(tokenized)).long()

# Load mel if needed
log_mel, log_mel_length = None, None
Expand Down Expand Up @@ -417,9 +432,10 @@ def __getitem__(self, index):
# Load alignment prior matrix if needed
align_prior_matrix = None
if AlignPriorMatrix in self.sup_data_types_set:
align_prior_matrix = None
if self.use_beta_binomial_interpolator:
mel_len = self.get_log_mel(audio).shape[2]
align_prior_matrix = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_length.item()))
mel_len = self.get_log_mel(audio).shape[2]
else:
prior_path = self.align_prior_matrix_folder / f"{rel_audio_path_as_text_id}.pt"

Expand Down
9 changes: 9 additions & 0 deletions nemo/collections/tts/torch/g2ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import abc
import pathlib
import random
import re
import time
from typing import Optional

import nltk
import torch
Expand Down Expand Up @@ -53,6 +55,7 @@ def __init__(
ignore_ambiguous_words=True,
heteronyms=None,
encoding='latin-1',
phoneme_probability: Optional[float] = None,
):
"""English G2P module. This module converts words from grapheme to phoneme representation using phoneme_dict in CMU dict format.
Optionally, it can ignore words which are heteronyms, ambiguous or marked as unchangeable by word_tokenize_func (see code for details).
Expand All @@ -67,6 +70,7 @@ def __init__(
ignore_ambiguous_words: Whether to not handle word via phoneme_dict with ambiguous phoneme sequences. Defaults to True.
heteronyms (str, Path, List): Path to file with heteronyms (every line is new word) or list of words.
encoding: Encoding type.
phoneme_probability (Optional[float]): The probability (0.<var<1.) that each word is phonemized. Defaults to None which is the same as 1.
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
"""
phoneme_dict = (
self._parse_as_cmu_dict(phoneme_dict, encoding)
Expand All @@ -91,6 +95,8 @@ def __init__(
if isinstance(heteronyms, str) or isinstance(heteronyms, pathlib.Path)
else heteronyms
)
self.phoneme_probability = phoneme_probability
self._rng = random.Random()

@staticmethod
def _parse_as_cmu_dict(phoneme_dict_path=None, encoding='latin-1'):
Expand Down Expand Up @@ -163,6 +169,9 @@ def parse_one_word(self, word: str):
`status` will be `False` if word wasn't handled, `True` otherwise.
"""

if self.phoneme_probability is not None and self._rng.random() > self.phoneme_probability:
return word, True

# punctuation
if re.search("[a-zA-Z]", word) is None:
return list(word), True
Expand Down
14 changes: 11 additions & 3 deletions nemo/collections/tts/torch/tts_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ def __init__(
Note that lower() function shouldn't applied here, because text can contains phonemes (it will be handled by g2p).
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
"""

self.phoneme_probability = None
if hasattr(g2p, "phoneme_probability"):
self.phoneme_probability = g2p.phoneme_probability
tokens = []
self.space, tokens = len(tokens), tokens + [space] # Space

Expand All @@ -295,7 +298,12 @@ def __init__(
vowels = [f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))]
tokens.extend(vowels)

if chars:
if chars or self.phoneme_probability is not None:
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
if not chars:
logging.warning(
"phoneme_probability was not None, characters will be enabled even though "
"chars was set to False."
)
tokens.extend(string.ascii_lowercase)

if apostrophe:
Expand All @@ -308,7 +316,7 @@ def __init__(

super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at)

self.chars = chars
self.chars = chars if self.phoneme_probability is None else True
self.punct = punct
self.stresses = stresses
self.pad_with_space = pad_with_space
Expand All @@ -321,7 +329,7 @@ def encode(self, text):
ps, space, tokens = [], self.tokens[self.space], set(self.tokens)

text = self.text_preprocessing_func(text)
g2p_text = self.g2p(text)
g2p_text = self.g2p(text) # TODO: handle infer

for p in g2p_text: # noqa
# Remove stress
Expand Down