Skip to content

Commit

Permalink
Make nemo text processing optional in TTS (NVIDIA#10584)
Browse files Browse the repository at this point in the history
* move TN guard to better location; make guard print error message rather than throwing error

Signed-off-by: Jason <jasoli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: blisc <blisc@users.noreply.github.com>

* Forgot to add the actual normalizer

Signed-off-by: Jason <jasoli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: blisc <blisc@users.noreply.github.com>

---------

Signed-off-by: Jason <jasoli@nvidia.com>
Signed-off-by: blisc <blisc@users.noreply.github.com>
Co-authored-by: blisc <blisc@users.noreply.github.com>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
  • Loading branch information
2 people authored and Hainan Xu committed Nov 5, 2024
1 parent 063de03 commit 34b4fdd
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 185 deletions.
3 changes: 2 additions & 1 deletion nemo/collections/tts/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ def __init__(
self.text_normalizer_call = None
elif not PYNINI_AVAILABLE:
raise ImportError(
"`nemo_text_processing` is not installed, see https://github.com/NVIDIA/NeMo-text-processing for details"
"`nemo_text_processing` is not installed, see https://github.com/NVIDIA/NeMo-text-processing for details. "
"If you wish to continue without text normalization, please remove the text_normalizer part in your TTS yaml file."
)
else:
self.text_normalizer_call = (
Expand Down
26 changes: 2 additions & 24 deletions nemo/collections/tts/models/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from torch import nn

from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.models.base import NeedsNormalizer
from nemo.collections.tts.parts.utils.helpers import (
binarize_attention,
g2p_backward_compatible_support,
Expand All @@ -41,7 +42,7 @@
HAVE_WANDB = False


class AlignerModel(ModelPT):
class AlignerModel(NeedsNormalizer, ModelPT):
"""Speech-to-text alignment model (https://arxiv.org/pdf/2108.10447.pdf) that is used to learn alignments between mel spectrogram and text."""

def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
Expand Down Expand Up @@ -77,29 +78,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
self.bin_loss_start_ratio = cfg.bin_loss_start_ratio
self.bin_loss_warmup_epochs = cfg.bin_loss_warmup_epochs

def _setup_normalizer(self, cfg):
if "text_normalizer" in cfg:
normalizer_kwargs = {}

if "whitelist" in cfg.text_normalizer:
normalizer_kwargs["whitelist"] = self.register_artifact(
'text_normalizer.whitelist', cfg.text_normalizer.whitelist
)

try:
import nemo_text_processing

self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs)
except Exception as e:
logging.error(e)
raise ImportError(
"`nemo_text_processing` not installed, see https://github.com/NVIDIA/NeMo-text-processing for more details"
)

self.text_normalizer_call = self.normalizer.normalize
if "text_normalizer_call_kwargs" in cfg:
self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs

def _setup_tokenizer(self, cfg):
text_tokenizer_kwargs = {}
if "g2p" in cfg.text_tokenizer:
Expand Down
54 changes: 44 additions & 10 deletions nemo/collections/tts/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import List, Optional

import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from tqdm import tqdm

Expand All @@ -28,9 +29,39 @@
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging, model_utils

PYNINI_AVAILABLE = True
try:
import nemo_text_processing
except (ImportError, ModuleNotFoundError):
PYNINI_AVAILABLE = False

class SpectrogramGenerator(ModelPT, ABC):
""" Base class for all TTS models that turn text into a spectrogram """

class NeedsNormalizer:
"""Base class for all TTS models that needs text normalization(TN)"""

def _setup_normalizer(self, cfg):
if "text_normalizer" in cfg:
if not PYNINI_AVAILABLE:
logging.error(
"`nemo_text_processing` not installed, see https://github.com/NVIDIA/NeMo-text-processing for more details."
)
logging.error("The normalizer will be disabled.")
return
normalizer_kwargs = {}

if "whitelist" in cfg.text_normalizer:
normalizer_kwargs["whitelist"] = self.register_artifact(
'text_normalizer.whitelist', cfg.text_normalizer.whitelist
)

self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs)
self.text_normalizer_call = self.normalizer.normalize
if "text_normalizer_call_kwargs" in cfg:
self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs


class SpectrogramGenerator(NeedsNormalizer, ModelPT, ABC):
"""Base class for all TTS models that turn text into a spectrogram"""

@abstractmethod
def parse(self, str_input: str, **kwargs) -> 'torch.tensor':
Expand Down Expand Up @@ -115,7 +146,7 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':


class GlowVocoder(Vocoder):
""" Base class for all Vocoders that use a Glow or reversible Flow-based setup. All child class are expected
"""Base class for all Vocoders that use a Glow or reversible Flow-based setup. All child class are expected
to have a parameter called audio_to_melspec_precessor that is an instance of
nemo.collections.asr.parts.FilterbankFeatures"""

Expand Down Expand Up @@ -175,7 +206,11 @@ def yet_another_patch(audio, n_fft, hop_length, win_length, window):
return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])

self.stft = lambda x: yet_another_patch(
x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
)
self.istft = lambda x, y: torch.istft(
torch.complex(x * torch.cos(y), x * torch.sin(y)),
Expand Down Expand Up @@ -252,15 +287,15 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
return list_of_models


class TextToWaveform(ModelPT, ABC):
""" Base class for all end-to-end TTS models that generate a waveform from text """
class TextToWaveform(NeedsNormalizer, ModelPT, ABC):
"""Base class for all end-to-end TTS models that generate a waveform from text"""

@abstractmethod
def parse(self, str_input: str, **kwargs) -> 'torch.tensor':
"""
A helper function that accepts a raw python string and turns it into a tensor. The tensor should have 2
dimensions. The first is the batch, which should be of size 1. The second should represent time. The tensor
should represent either tokenized or embedded text, depending on the model.
A helper function that accepts a raw python string and turns it into a tensor. The tensor should have 2
dimensions. The first is the batch, which should be of size 1. The second should represent time. The tensor
should represent either tokenized or embedded text, depending on the model.
"""

@abstractmethod
Expand Down Expand Up @@ -299,7 +334,6 @@ def convert_graphemes_to_phonemes(
num_workers: int = 0,
pred_field: Optional[str] = "pred_text",
) -> List[str]:

"""
Main function for Inference. Converts grapheme entries from the manifest "graheme_field" to phonemes
Args:
Expand Down
73 changes: 42 additions & 31 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,28 +200,6 @@ def _get_default_text_tokenizer_conf(self):
text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer))

def _setup_normalizer(self, cfg):
if "text_normalizer" in cfg:
normalizer_kwargs = {}

if "whitelist" in cfg.text_normalizer:
normalizer_kwargs["whitelist"] = self.register_artifact(
'text_normalizer.whitelist', cfg.text_normalizer.whitelist
)
try:
import nemo_text_processing

self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs)
except Exception as e:
logging.error(e)
raise ImportError(
"`nemo_text_processing` not installed, see https://github.com/NVIDIA/NeMo-text-processing for more details"
)

self.text_normalizer_call = self.normalizer.normalize
if "text_normalizer_call_kwargs" in cfg:
self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs

def _setup_tokenizer(self, cfg):
text_tokenizer_kwargs = {}

Expand All @@ -240,12 +218,14 @@ def _setup_tokenizer(self, cfg):

if "phoneme_dict" in cfg.text_tokenizer.g2p:
g2p_kwargs["phoneme_dict"] = self.register_artifact(
'text_tokenizer.g2p.phoneme_dict', cfg.text_tokenizer.g2p.phoneme_dict,
'text_tokenizer.g2p.phoneme_dict',
cfg.text_tokenizer.g2p.phoneme_dict,
)

if "heteronyms" in cfg.text_tokenizer.g2p:
g2p_kwargs["heteronyms"] = self.register_artifact(
'text_tokenizer.g2p.heteronyms', cfg.text_tokenizer.g2p.heteronyms,
'text_tokenizer.g2p.heteronyms',
cfg.text_tokenizer.g2p.heteronyms,
)

# for backward compatability
Expand Down Expand Up @@ -478,16 +458,25 @@ def training_step(self, batch, batch_idx):
)
spec_predict = mels_pred[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"train_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
"train_mel_predicted",
plot_spectrogram_to_numpy(spec_predict),
self.global_step,
dataformats="HWC",
)
if self.learn_alignment:
attn = attn_hard[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_attn", plot_alignment_to_numpy(attn.T), self.global_step, dataformats="HWC",
"train_attn",
plot_alignment_to_numpy(attn.T),
self.global_step,
dataformats="HWC",
)
soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
self.tb_logger.add_image(
"train_soft_attn", plot_alignment_to_numpy(soft_attn.T), self.global_step, dataformats="HWC",
"train_soft_attn",
plot_alignment_to_numpy(soft_attn.T),
self.global_step,
dataformats="HWC",
)

return loss
Expand Down Expand Up @@ -527,7 +516,20 @@ def validation_step(self, batch, batch_idx):
)

# Calculate val loss on ground truth durations to better align L2 loss in time
(mels_pred, _, _, log_durs_pred, pitch_pred, _, _, _, attn_hard_dur, pitch, energy_pred, energy_tgt,) = self(
(
mels_pred,
_,
_,
log_durs_pred,
pitch_pred,
_,
_,
_,
attn_hard_dur,
pitch,
energy_pred,
energy_tgt,
) = self(
text=text,
durs=durs,
pitch=pitch,
Expand Down Expand Up @@ -587,7 +589,10 @@ def on_validation_epoch_end(self):
)
spec_predict = spec_predict[0].data.cpu().float().numpy()
self.tb_logger.add_image(
"val_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC",
"val_mel_predicted",
plot_spectrogram_to_numpy(spec_predict),
self.global_step,
dataformats="HWC",
)
self.log_train_images = True
self.validation_step_outputs.clear() # free memory)
Expand All @@ -598,7 +603,10 @@ def _setup_train_dataloader(self, cfg):
phon_mode = self.vocab.set_phone_prob(self.vocab.phoneme_probability)

with phon_mode:
dataset = instantiate(cfg.dataset, text_tokenizer=self.vocab,)
dataset = instantiate(
cfg.dataset,
text_tokenizer=self.vocab,
)

sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
return torch.utils.data.DataLoader(
Expand All @@ -611,7 +619,10 @@ def _setup_test_dataloader(self, cfg):
phon_mode = self.vocab.set_phone_prob(0.0)

with phon_mode:
dataset = instantiate(cfg.dataset, text_tokenizer=self.vocab,)
dataset = instantiate(
cfg.dataset,
text_tokenizer=self.vocab,
)

return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)

Expand Down
Loading

0 comments on commit 34b4fdd

Please sign in to comment.