From a0beaf6ad56843981d91e83bbd1e70b93291dc26 Mon Sep 17 00:00:00 2001 From: Ryan Date: Thu, 20 Apr 2023 17:10:16 -0700 Subject: [PATCH 1/2] [TTS] Add cosine distance option to TTS aligner Signed-off-by: Ryan --- examples/tts/conf/fastpitch/fastpitch.yaml | 2 + nemo/collections/tts/models/fastpitch.py | 16 ++--- nemo/collections/tts/modules/aligner.py | 78 ++++++++++++++++------ 3 files changed, 66 insertions(+), 30 deletions(-) diff --git a/examples/tts/conf/fastpitch/fastpitch.yaml b/examples/tts/conf/fastpitch/fastpitch.yaml index 1d552d058d76..39d5f395afbc 100644 --- a/examples/tts/conf/fastpitch/fastpitch.yaml +++ b/examples/tts/conf/fastpitch/fastpitch.yaml @@ -193,6 +193,8 @@ model: alignment_module: _target_: nemo.collections.tts.modules.aligner.AlignmentEncoder n_text_channels: ${model.symbols_embedding_dim} + dist_type: cosine + temperature: 15.0 duration_predictor: _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 1a68d9e51aeb..dc598a9a76d1 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -121,16 +121,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.log_images = cfg.get("log_images", False) self.log_train_images = False - loss_scale = 0.1 if self.learn_alignment else 1.0 - dur_loss_scale = loss_scale - pitch_loss_scale = loss_scale - energy_loss_scale = loss_scale - if "dur_loss_scale" in cfg: - dur_loss_scale = cfg.dur_loss_scale - if "pitch_loss_scale" in cfg: - pitch_loss_scale = cfg.pitch_loss_scale - if "energy_loss_scale" in cfg: - energy_loss_scale = cfg.energy_loss_scale + default_prosody_loss_scale = 0.1 if self.learn_alignment else 1.0 + dur_loss_scale = cfg.get("dur_loss_scale", default_prosody_loss_scale) + pitch_loss_scale = cfg.get("pitch_loss_scale", default_prosody_loss_scale) + energy_loss_scale = cfg.get("energy_loss_scale", default_prosody_loss_scale) self.mel_loss_fn = MelLoss() self.pitch_loss_fn = PitchLoss(loss_scale=pitch_loss_scale) @@ -139,7 +133,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.aligner = None if self.learn_alignment: - aligner_loss_scale = cfg.aligner_loss_scale if "aligner_loss_scale" in cfg else 1.0 + aligner_loss_scale = cfg.get("aligner_loss_scale", 1.0) self.aligner = instantiate(self._cfg.alignment_module) self.forward_sum_loss_fn = ForwardSumLoss(loss_scale=aligner_loss_scale) self.bin_loss_fn = BinLoss(loss_scale=aligner_loss_scale) diff --git a/nemo/collections/tts/modules/aligner.py b/nemo/collections/tts/modules/aligner.py index bc170742df23..198ddb635446 100644 --- a/nemo/collections/tts/modules/aligner.py +++ b/nemo/collections/tts/modules/aligner.py @@ -14,6 +14,7 @@ import torch +from einops import rearrange from torch import nn from nemo.collections.tts.modules.submodules import ConditionalInput, ConvNorm @@ -24,7 +25,13 @@ class AlignmentEncoder(torch.nn.Module): """Module for alignment text and mel spectrogram. """ def __init__( - self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=0.0005, condition_types=[] + self, + n_mel_channels=80, + n_text_channels=512, + n_att_channels=80, + temperature=0.0005, + condition_types=[], + dist_type="l2", ): super().__init__() self.temperature = temperature @@ -45,6 +52,20 @@ def __init__( torch.nn.ReLU(), ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True), ) + if dist_type == "l2": + self.dist_fn = self.get_euclidean_dist + elif dist_type == "cosine": + self.dist_fn = self.get_cosine_dist + else: + raise ValueError(f"Unknown distance type '{dist_type}'") + + @staticmethod + def _apply_mask(inputs, mask, mask_value): + if mask is None: + return + + mask = rearrange(mask, "B T2 1 -> B 1 1 T2") + inputs.data.masked_fill_(mask, mask_value) def get_dist(self, keys, queries, mask=None): """Calculation of distance matrix. @@ -57,15 +78,35 @@ def get_dist(self, keys, queries, mask=None): Output: dist (torch.tensor): B x T1 x T2 tensor. """ - keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 - queries_enc = self.query_proj(queries) # B x n_attn_dims x T1 - attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2 - dist = attn.sum(1, keepdim=True) # B x 1 x T1 x T2 + # B x n_attn_dims x T2 + keys_enc = self.key_proj(keys) + # B x n_attn_dims x T1 + queries_enc = self.query_proj(queries) + + # B x 1 x T1 x T2 + dist = self.dist_fn(queries=queries_enc, keys=keys_enc) + + self._apply_mask(dist, mask, float("inf")) - if mask is not None: - dist.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), float("inf")) + return dist - return dist.squeeze(1) + @staticmethod + def get_euclidean_dist(queries, keys): + queries = rearrange(queries, "B C T1 -> B C T1 1") + keys = rearrange(keys, "B C T2 -> B C 1 T2") + # B x C x T1 x T2 + distance = (queries - keys) ** 2 + # B x 1 x T1 x T2 + l2_dist = distance.sum(axis=1, keepdim=True) + return l2_dist + + @staticmethod + def get_cosine_dist(queries, keys): + queries = rearrange(queries, "B C T1 -> B C T1 1") + keys = rearrange(keys, "B C T2 -> B C 1 T2") + cosine_dist = -torch.nn.functional.cosine_similarity(queries, keys, dim=1) + cosine_dist = rearrange(cosine_dist, "B T1 T2 -> B 1 T1 T2") + return cosine_dist @staticmethod def get_durations(attn_soft, text_len, spect_len): @@ -96,8 +137,7 @@ def get_mean_dist_by_durations(dist, durations, mask=None): batch_size, t1_size, t2_size = dist.size() assert torch.all(torch.eq(durations.sum(dim=1), t1_size)) - if mask is not None: - dist = dist.masked_fill(mask.permute(0, 2, 1).unsqueeze(2), 0) + AlignmentEncoder._apply_mask(dist, mask, 0) # TODO(oktai15): make it more efficient mean_dist_by_durations = [] @@ -149,7 +189,7 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None): """Forward pass of the aligner encoder. Args: - queries (torch.tensor): B x C x T1 tensor (probably going to be mel data). + queries (torch.tensor): B x C1 x T1 tensor (probably going to be mel data). keys (torch.tensor): B x C2 x T2 tensor (text data). mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries (True = mask element, False = leave unchanged). attn_prior (torch.tensor): prior for attention matrix. @@ -159,20 +199,20 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None): attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask. """ keys = self.cond_input(keys.transpose(1, 2), conditioning).transpose(1, 2) - keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 - queries_enc = self.query_proj(queries) # B x n_attn_dims x T1 - - # Simplistic Gaussian Isotopic Attention - attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2 - attn = -self.temperature * attn.sum(1, keepdim=True) + # B x C x T2 + keys_enc = self.key_proj(keys) + # B x C x T1 + queries_enc = self.query_proj(queries) + # B x 1 x T1 x T2 + distance = self.dist_fn(queries=queries_enc, keys=keys_enc) + attn = -self.temperature * distance if attn_prior is not None: attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8) attn_logprob = attn.clone() - if mask is not None: - attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf")) + self._apply_mask(attn, mask, -float("inf")) attn = self.softmax(attn) # softmax along T2 return attn, attn_logprob From 8f32238db9fc8e2c2daf2ad5d8492a5b4bd7b299 Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 10 Jul 2023 14:33:07 -0700 Subject: [PATCH 2/2] [TTS] Update aligner comments Signed-off-by: Ryan --- nemo/collections/tts/modules/aligner.py | 46 +++++++++++++--------- nemo/collections/tts/modules/submodules.py | 2 +- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/nemo/collections/tts/modules/aligner.py b/nemo/collections/tts/modules/aligner.py index 198ddb635446..2910602474fd 100644 --- a/nemo/collections/tts/modules/aligner.py +++ b/nemo/collections/tts/modules/aligner.py @@ -22,7 +22,18 @@ class AlignmentEncoder(torch.nn.Module): - """Module for alignment text and mel spectrogram. """ + """ + Module for alignment text and mel spectrogram. + + Args: + n_mel_channels: Dimension of mel spectrogram. + n_text_channels: Dimension of text embeddings. + n_att_channels: Dimension of model + temperature: Temperature to scale distance by. + Suggested to be 0.0005 when using dist_type "l2" and 15.0 when using "cosine". + condition_types: List of types for nemo.collections.tts.modules.submodules.ConditionalInput. + dist_type: Distance type to use for similarity measurement. Supports "l2" and "cosine" distance. + """ def __init__( self, @@ -71,40 +82,39 @@ def get_dist(self, keys, queries, mask=None): """Calculation of distance matrix. Args: - queries (torch.tensor): B x C x T1 tensor (probably going to be mel data). + queries (torch.tensor): B x C1 x T1 tensor (probably going to be mel data). keys (torch.tensor): B x C2 x T2 tensor (text data). mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries and also can be used for ignoring unnecessary elements from keys in the resulting distance matrix (True = mask element, False = leave unchanged). Output: dist (torch.tensor): B x T1 x T2 tensor. """ - # B x n_attn_dims x T2 - keys_enc = self.key_proj(keys) - # B x n_attn_dims x T1 + # B x C x T1 queries_enc = self.query_proj(queries) - + # B x C x T2 + keys_enc = self.key_proj(keys) # B x 1 x T1 x T2 - dist = self.dist_fn(queries=queries_enc, keys=keys_enc) + dist = self.dist_fn(queries_enc=queries_enc, keys_enc=keys_enc) self._apply_mask(dist, mask, float("inf")) return dist @staticmethod - def get_euclidean_dist(queries, keys): - queries = rearrange(queries, "B C T1 -> B C T1 1") - keys = rearrange(keys, "B C T2 -> B C 1 T2") + def get_euclidean_dist(queries_enc, keys_enc): + queries_enc = rearrange(queries_enc, "B C T1 -> B C T1 1") + keys_enc = rearrange(keys_enc, "B C T2 -> B C 1 T2") # B x C x T1 x T2 - distance = (queries - keys) ** 2 + distance = (queries_enc - keys_enc) ** 2 # B x 1 x T1 x T2 l2_dist = distance.sum(axis=1, keepdim=True) return l2_dist @staticmethod - def get_cosine_dist(queries, keys): - queries = rearrange(queries, "B C T1 -> B C T1 1") - keys = rearrange(keys, "B C T2 -> B C 1 T2") - cosine_dist = -torch.nn.functional.cosine_similarity(queries, keys, dim=1) + def get_cosine_dist(queries_enc, keys_enc): + queries_enc = rearrange(queries_enc, "B C T1 -> B C T1 1") + keys_enc = rearrange(keys_enc, "B C T2 -> B C 1 T2") + cosine_dist = -torch.nn.functional.cosine_similarity(queries_enc, keys_enc, dim=1) cosine_dist = rearrange(cosine_dist, "B T1 T2 -> B 1 T1 T2") return cosine_dist @@ -199,12 +209,12 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None): attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask. """ keys = self.cond_input(keys.transpose(1, 2), conditioning).transpose(1, 2) - # B x C x T2 - keys_enc = self.key_proj(keys) # B x C x T1 queries_enc = self.query_proj(queries) + # B x C x T2 + keys_enc = self.key_proj(keys) # B x 1 x T1 x T2 - distance = self.dist_fn(queries=queries_enc, keys=keys_enc) + distance = self.dist_fn(queries_enc=queries_enc, keys_enc=keys_enc) attn = -self.temperature * distance if attn_prior is not None: diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 408ab02dead2..92218e807aac 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -509,7 +509,7 @@ def forward(self, inputs, conditioning=None): inputs = inputs + conditioning if "concat" in self.condition_types: - conditioning = conditionting.repeat(1, inputs.shape[1], 1) + conditioning = conditioning.repeat(1, inputs.shape[1], 1) inputs = torch.cat([inputs, conditioning]) inputs = self.concat_proj(inputs)