Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Add Mixed Representation Training #3473

Merged
merged 35 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
25e7dff
Update CMUdict with ADLR version pronunciations
redoctopus Jan 14, 2022
bf21276
minor updates for finetuning
blisc Jan 18, 2022
6bb0f0b
update conf
blisc Jan 19, 2022
8db6e5c
merge
blisc Jan 19, 2022
ed6121f
Merge remote-tracking branch 'nvidia/cmudict_update' into tts_finetun…
blisc Jan 19, 2022
725e15c
update
blisc Jan 19, 2022
d0f7edc
update
blisc Jan 19, 2022
eacd2ba
bug fixes
blisc Jan 19, 2022
3793c13
update config
blisc Jan 20, 2022
729fc4c
bf16 support
blisc Jan 20, 2022
539a4ab
bf16 support
blisc Jan 20, 2022
05309f1
bugfix
blisc Jan 21, 2022
b479e7a
update
blisc Jan 24, 2022
3869f79
finalize changes
blisc Jan 26, 2022
cdc1f6c
merge with main
blisc Feb 1, 2022
23aeb2c
undo notebook 1.6.0 pins
blisc Feb 1, 2022
0d33f71
more 1.6.0 undos
blisc Feb 1, 2022
780828c
wip
blisc Feb 3, 2022
b54f390
update num_workers
blisc Feb 3, 2022
b09dee9
update hypers
blisc Feb 3, 2022
a4124e3
revert to main _align yamls
blisc Feb 4, 2022
7898a40
update yamls
blisc Feb 4, 2022
1342c23
cleanup
blisc Feb 4, 2022
832f431
merge with main
blisc Feb 4, 2022
d8f796e
remove unnecessary line
blisc Feb 7, 2022
a31e824
address comments
blisc Feb 7, 2022
4f09e27
merge
blisc Feb 7, 2022
341deb2
update vocoder mel uploading; add contextmanager to mixed g2p
blisc Feb 10, 2022
1201136
update comments; make prob required argument
blisc Feb 11, 2022
bc218ce
added val check
blisc Feb 11, 2022
0f9386e
update message
blisc Feb 11, 2022
8d0baaa
update
blisc Feb 11, 2022
2a34a70
Merge remote-tracking branch 'nvidia/main' into tts_finetuning_updates
blisc Feb 11, 2022
28d1fd5
revert num workers
blisc Feb 11, 2022
3e3fc21
Merge branch 'main' into tts_finetuning_updates
blisc Feb 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -2210,14 +2210,13 @@ pipeline {
~trainer.check_val_every_n_epoch'
}
}
// TODO(Oktai15): update it in 1.8.0 version
stage('FastPitch') {
steps {
sh 'python examples/tts/fastpitch.py \
--config-name fastpitch_align \
--config-name fastpitch_align_v1.05 \
train_dataset=/home/TestData/an4_dataset/an4_train.json \
validation_datasets=/home/TestData/an4_dataset/an4_val.json \
prior_folder=/home/TestData/an4_dataset/beta_priors \
sup_data_path=/home/TestData/an4_dataset/beta_priors \
trainer.devices="[0]" \
+trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
trainer.strategy=null \
Expand Down
3 changes: 3 additions & 0 deletions examples/tts/conf/fastpitch_align_v1.05.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ model:
_target_: nemo.collections.tts.torch.g2ps.EnglishG2p
phoneme_dict: ${phoneme_dict_path}
heteronyms: ${heteronyms_path}
phoneme_probability: 0.5

train_ds:
dataset:
Expand All @@ -101,6 +102,7 @@ model:
pitch_norm: true
pitch_mean: ${model.pitch_mean}
pitch_std: ${model.pitch_std}
use_beta_binomial_interpolator: true

dataloader_params:
drop_last: false
Expand Down Expand Up @@ -131,6 +133,7 @@ model:
pitch_norm: true
pitch_mean: ${model.pitch_mean}
pitch_std: ${model.pitch_std}
use_beta_binomial_interpolator: true

dataloader_params:
drop_last: false
Expand Down
2 changes: 1 addition & 1 deletion examples/tts/conf/hifigan/hifigan_44100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ train_n_segments: 16384
train_max_duration: null
train_min_duration: 0.75

val_n_segments: 132096
val_n_segments: 131072
val_max_duration: null
val_min_duration: 3

Expand Down
6 changes: 6 additions & 0 deletions nemo/collections/common/data/vocabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
import unicodedata
from builtins import str as unicode
from contextlib import contextmanager
from typing import List

import nltk
Expand Down Expand Up @@ -375,3 +376,8 @@ def encode(self, text):
ps = [space] + ps + [space]

return [self._label2id[p] for p in ps]

@contextmanager
def set_phone_prob(self, prob=None):
# Add do nothing since this class doesn't support mixed g2p
yield
39 changes: 24 additions & 15 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,21 @@ def parser(self):
return self._parser

def parse(self, str_input: str, normalize=True) -> torch.tensor:
if self.training:
logging.warning("parse() is meant to be called in eval mode.")
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
if str_input[-1] not in [".", "!", "?"]:
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
str_input = str_input + "."

if normalize and self.text_normalizer_call is not None:
str_input = self.text_normalizer_call(str_input, **self.text_normalizer_call_kwargs)

tokens = self.parser(str_input)
if self.learn_alignment:
# Disable mixed g2p representation
with self.vocab.set_phone_prob(prob=1.0):
tokens = self.parser(str_input)
else:
# TODO(Oktai15): remove it in 1.8.0 version
tokens = self.parser(str_input)

x = torch.tensor(tokens).unsqueeze_(0).long().to(self.device)
return x
Expand Down Expand Up @@ -246,8 +254,8 @@ def forward(

@typecheck(output_types={"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())})
def generate_spectrogram(self, tokens: 'torch.tensor', speaker: int = 0, pace: float = 1.0) -> torch.tensor:
# FIXME: return masks as well?
self.eval()
if self.training:
logging.warning("generate_spectrogram() is meant to be called in eval mode.")
if isinstance(speaker, int):
speaker = torch.tensor([speaker]).to(self.device)
spect, *_ = self(text=tokens, durs=None, pitch=None, speaker=speaker, pace=pace)
Expand Down Expand Up @@ -312,20 +320,20 @@ def training_step(self, batch, batch_idx):

self.tb_logger.add_image(
"train_mel_target",
plot_spectrogram_to_numpy(mels[0].data.cpu().numpy()),
plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
self.global_step,
dataformats="HWC",
)
spec_predict = mels_pred[0].data.cpu().numpy()
spec_predict = mels_pred[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"train_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
)
if self.learn_alignment:
attn = attn_hard[0].data.cpu().numpy().squeeze()
attn = attn_hard[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_attn", plot_alignment_to_numpy(attn.T), self.global_step, dataformats="HWC",
)
soft_attn = attn_soft[0].data.cpu().numpy().squeeze()
soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_soft_attn", plot_alignment_to_numpy(soft_attn.T), self.global_step, dataformats="HWC",
)
Expand Down Expand Up @@ -396,11 +404,11 @@ def validation_epoch_end(self, outputs):
if isinstance(self.logger, TensorBoardLogger):
self.tb_logger.add_image(
"val_mel_target",
plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()),
self.global_step,
dataformats="HWC",
)
spec_predict = spec_predict[0].data.cpu().numpy()
spec_predict = spec_predict[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"val_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
)
Expand Down Expand Up @@ -428,12 +436,13 @@ def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, na
if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
dataset = instantiate(cfg.dataset, parser=self.parser)
elif cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
dataset = instantiate(
cfg.dataset,
text_normalizer=self.normalizer,
text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
text_tokenizer=self.vocab,
)
with self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability):
dataset = instantiate(
cfg.dataset,
text_normalizer=self.normalizer,
text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
text_tokenizer=self.vocab,
)
else:
# TODO(Oktai15): remove it in 1.8.0 version
dataset = instantiate(cfg.dataset)
Expand Down
6 changes: 5 additions & 1 deletion nemo/collections/tts/models/mixer_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,13 @@ def generate_spectrogram(
return pred_spect

def parse(self, text: str, normalize=True) -> torch.Tensor:
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
if self.training:
logging.warning("parse() is meant to be called in eval mode.")
if normalize and self.text_normalizer_call is not None:
text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)
return torch.tensor(self.tokenizer.encode(text)).long().unsqueeze(0).to(self.device)
with self.tokenizer.set_phone_prob(prob=1.0):
tokens = self.tokenizer.encode(text)
return torch.tensor(tokens).long().unsqueeze(0).to(self.device)

def _loader(self, cfg):
try:
Expand Down
36 changes: 28 additions & 8 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Callable, Dict, List, Optional, Union

import librosa
import numpy as np
import torch
from nemo_text_processing.text_normalization.normalize import Normalizer
from tqdm import tqdm
Expand Down Expand Up @@ -134,9 +135,12 @@ def __init__(

# Initialize text tokenizer
self.text_tokenizer = text_tokenizer

self.phoneme_probability = None
if isinstance(self.text_tokenizer, BaseTokenizer):
self.text_tokenizer_pad_id = text_tokenizer.pad
self.tokens = text_tokenizer.tokens
self.phoneme_probability = self.text_tokenizer.phoneme_probability
else:
if text_tokenizer_pad_id is None:
raise ValueError(f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer")
Expand All @@ -146,6 +150,7 @@ def __init__(

self.text_tokenizer_pad_id = text_tokenizer_pad_id
self.tokens = tokens
self.cache_text = True if self.phoneme_probability is None else False

# Initialize text normalizer is specified
self.text_normalizer = text_normalizer
Expand Down Expand Up @@ -179,15 +184,14 @@ def __init__(

if "normalized_text" not in item:
text = item["text"]

if self.text_normalizer is not None:
text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)

file_info["normalized_text"] = text
file_info["text_tokens"] = self.text_tokenizer(text)
else:
file_info["normalized_text"] = item["normalized_text"]
file_info["text_tokens"] = self.text_tokenizer(item["normalized_text"])

if self.cache_text:
file_info["text_tokens"] = self.text_tokenizer(file_info["normalized_text"])

data.append(file_info)

Expand Down Expand Up @@ -241,6 +245,7 @@ def __init__(
hop_length=self.hop_len,
win_length=self.win_length,
window=window_fn(self.win_length, periodic=False).to(torch.float) if window_fn else None,
return_complex=True,
)

# Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
Expand Down Expand Up @@ -331,6 +336,13 @@ def add_align_prior_matrix(self, **kwargs):
self.align_prior_matrix_folder.mkdir(exist_ok=True, parents=True)

self.use_beta_binomial_interpolator = kwargs.pop('use_beta_binomial_interpolator', False)
if not self.cache_text:
if 'use_beta_binomial_interpolator' in kwargs and not self.use_beta_binomial_interpolator:
logging.warning(
"phoneme_probability is not None, but use_beta_binomial_interpolator=False, we"
" set use_beta_binomial_interpolator=True manually to use phoneme_probability."
)
self.use_beta_binomial_interpolator = True
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved

if self.use_beta_binomial_interpolator:
self.beta_binomial_interpolator = BetaBinomialInterpolator()
Expand Down Expand Up @@ -386,9 +398,13 @@ def __getitem__(self, index):
features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
audio, audio_length = features, torch.tensor(features.shape[0]).long()

# Load text
text = torch.tensor(sample["text_tokens"]).long()
text_length = torch.tensor(len(sample["text_tokens"])).long()
if "text_tokens" in sample:
text = torch.tensor(sample["text_tokens"]).long()
text_length = torch.tensor(len(sample["text_tokens"])).long()
else:
tokenized = self.text_tokenizer(sample["normalized_text"])
text = torch.tensor(tokenized).long()
text_length = torch.tensor(len(tokenized)).long()

# Load mel if needed
log_mel, log_mel_length = None, None
Expand Down Expand Up @@ -417,6 +433,7 @@ def __getitem__(self, index):
# Load alignment prior matrix if needed
align_prior_matrix = None
if AlignPriorMatrix in self.sup_data_types_set:
align_prior_matrix = None
if self.use_beta_binomial_interpolator:
mel_len = self.get_log_mel(audio).shape[2]
align_prior_matrix = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_length.item()))
Expand Down Expand Up @@ -823,7 +840,10 @@ def __getitem__(self, index):
features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
audio, audio_length = features, torch.tensor(features.shape[0]).long()

mel = torch.load(sample["mel_filepath"])
if Path(sample["mel_filepath"]).suffix == ".npy":
mel = np.load(sample["mel_filepath"])
else:
mel = torch.load(sample["mel_filepath"])
frames = math.ceil(self.n_segments / self.hop_length)

if len(audio) > self.n_segments:
Expand Down
11 changes: 11 additions & 0 deletions nemo/collections/tts/torch/g2ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import abc
import pathlib
import random
import re
import time
from typing import Optional

import nltk
import torch
Expand Down Expand Up @@ -53,6 +55,7 @@ def __init__(
ignore_ambiguous_words=True,
heteronyms=None,
encoding='latin-1',
phoneme_probability: Optional[float] = None,
):
"""English G2P module. This module converts words from grapheme to phoneme representation using phoneme_dict in CMU dict format.
Optionally, it can ignore words which are heteronyms, ambiguous or marked as unchangeable by word_tokenize_func (see code for details).
Expand All @@ -67,6 +70,9 @@ def __init__(
ignore_ambiguous_words: Whether to not handle word via phoneme_dict with ambiguous phoneme sequences. Defaults to True.
heteronyms (str, Path, List): Path to file with heteronyms (every line is new word) or list of words.
encoding: Encoding type.
phoneme_probability (Optional[float]): The probability (0.<var<1.) that each word is phonemized. Defaults to None which is the same as 1.
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
Note that this code path is only run if the word can be phonemized. For example: If the word does not have a entry in the g2p dict, it will be returned
as characters. If the word has multiple entries and ignore_ambiguous_words is True, it will be returned as characters.
"""
phoneme_dict = (
self._parse_as_cmu_dict(phoneme_dict, encoding)
Expand All @@ -91,6 +97,8 @@ def __init__(
if isinstance(heteronyms, str) or isinstance(heteronyms, pathlib.Path)
else heteronyms
)
self.phoneme_probability = phoneme_probability
self._rng = random.Random()

@staticmethod
def _parse_as_cmu_dict(phoneme_dict_path=None, encoding='latin-1'):
Expand Down Expand Up @@ -163,6 +171,9 @@ def parse_one_word(self, word: str):
`status` will be `False` if word wasn't handled, `True` otherwise.
"""

if self.phoneme_probability is not None and self._rng.random() > self.phoneme_probability:
return word, True

# punctuation
if re.search("[a-zA-Z]", word) is None:
return list(word), True
Expand Down
25 changes: 22 additions & 3 deletions nemo/collections/tts/torch/tts_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import abc
import itertools
import string
from contextlib import contextmanager
from typing import List

from nemo.collections.tts.torch.de_utils import german_text_preprocessing
Expand Down Expand Up @@ -282,6 +283,9 @@ def __init__(
Note that lower() function shouldn't applied here, because text can contains phonemes (it will be handled by g2p).
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
"""

self.phoneme_probability = None
if hasattr(g2p, "phoneme_probability"):
self.phoneme_probability = g2p.phoneme_probability
tokens = []
self.space, tokens = len(tokens), tokens + [space] # Space

Expand All @@ -295,7 +299,12 @@ def __init__(
vowels = [f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))]
tokens.extend(vowels)

if chars:
if chars or self.phoneme_probability is not None:
Oktai15 marked this conversation as resolved.
Show resolved Hide resolved
if not chars:
logging.warning(
"phoneme_probability was not None, characters will be enabled even though "
"chars was set to False."
)
tokens.extend(string.ascii_lowercase)

if apostrophe:
Expand All @@ -308,7 +317,7 @@ def __init__(

super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at)

self.chars = chars
self.chars = chars if self.phoneme_probability is None else True
self.punct = punct
self.stresses = stresses
self.pad_with_space = pad_with_space
Expand All @@ -321,7 +330,7 @@ def encode(self, text):
ps, space, tokens = [], self.tokens[self.space], set(self.tokens)

text = self.text_preprocessing_func(text)
g2p_text = self.g2p(text)
g2p_text = self.g2p(text) # TODO: handle infer

for p in g2p_text: # noqa
# Remove stress
Expand Down Expand Up @@ -351,3 +360,13 @@ def encode(self, text):
ps = [space] + ps + [space]

return [self._token2id[p] for p in ps]

@contextmanager
def set_phone_prob(self, prob):
if hasattr(self.g2p, "phoneme_probability"):
self.g2p.phoneme_probability = prob
try:
yield
finally:
if hasattr(self.g2p, "phoneme_probability"):
self.g2p.phoneme_probability = self.phoneme_probability