[TTS] Add cosine distance option to TTS aligner #6806

Merged · 2 commits · Jul 11, 2023
2 changes: 2 additions & 0 deletions examples/tts/conf/fastpitch/fastpitch.yaml
@@ -193,6 +193,8 @@ model:
alignment_module:
_target_: nemo.collections.tts.modules.aligner.AlignmentEncoder
n_text_channels: ${model.symbols_embedding_dim}
dist_type: cosine
temperature: 15.0

duration_predictor:
_target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor
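The two new keys above (`dist_type`, `temperature`) flow straight through Hydra into `AlignmentEncoder.__init__`. A minimal sketch of that wiring, with `n_text_channels` hard-coded to a made-up 384 instead of the `${model.symbols_embedding_dim}` interpolation:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical standalone config mirroring the fastpitch.yaml hunk above.
cfg = OmegaConf.create(
    {
        "_target_": "nemo.collections.tts.modules.aligner.AlignmentEncoder",
        "n_text_channels": 384,  # stands in for ${model.symbols_embedding_dim}
        "dist_type": "cosine",
        "temperature": 15.0,
    }
)
aligner = instantiate(cfg)  # AlignmentEncoder(..., dist_type="cosine", temperature=15.0)
```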
16 changes: 5 additions & 11 deletions nemo/collections/tts/models/fastpitch.py
@@ -121,16 +121,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.log_images = cfg.get("log_images", False)
self.log_train_images = False

loss_scale = 0.1 if self.learn_alignment else 1.0
dur_loss_scale = loss_scale
pitch_loss_scale = loss_scale
energy_loss_scale = loss_scale
if "dur_loss_scale" in cfg:
dur_loss_scale = cfg.dur_loss_scale
if "pitch_loss_scale" in cfg:
pitch_loss_scale = cfg.pitch_loss_scale
if "energy_loss_scale" in cfg:
energy_loss_scale = cfg.energy_loss_scale
default_prosody_loss_scale = 0.1 if self.learn_alignment else 1.0
dur_loss_scale = cfg.get("dur_loss_scale", default_prosody_loss_scale)
pitch_loss_scale = cfg.get("pitch_loss_scale", default_prosody_loss_scale)
energy_loss_scale = cfg.get("energy_loss_scale", default_prosody_loss_scale)

self.mel_loss_fn = MelLoss()
self.pitch_loss_fn = PitchLoss(loss_scale=pitch_loss_scale)
@@ -139,7 +133,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self.aligner = None
if self.learn_alignment:
aligner_loss_scale = cfg.aligner_loss_scale if "aligner_loss_scale" in cfg else 1.0
aligner_loss_scale = cfg.get("aligner_loss_scale", 1.0)
self.aligner = instantiate(self._cfg.alignment_module)
self.forward_sum_loss_fn = ForwardSumLoss(loss_scale=aligner_loss_scale)
self.bin_loss_fn = BinLoss(loss_scale=aligner_loss_scale)
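The refactor above collapses each `if "key" in cfg` branch into a single `DictConfig.get` call, which falls back to the default when the key is absent. A quick sketch of the equivalence, with made-up values:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"pitch_loss_scale": 0.5})  # made-up config
default_prosody_loss_scale = 0.1  # the default when learn_alignment is True

# Old pattern, repeated once per key:
pitch = cfg.pitch_loss_scale if "pitch_loss_scale" in cfg else default_prosody_loss_scale
# New pattern, one line per key:
assert cfg.get("pitch_loss_scale", default_prosody_loss_scale) == pitch == 0.5
assert cfg.get("dur_loss_scale", default_prosody_loss_scale) == 0.1  # key absent
```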
92 changes: 71 additions & 21 deletions nemo/collections/tts/modules/aligner.py
@@ -14,17 +14,35 @@


import torch
from einops import rearrange
from torch import nn

from nemo.collections.tts.modules.submodules import ConditionalInput, ConvNorm
from nemo.collections.tts.parts.utils.helpers import binarize_attention_parallel


class AlignmentEncoder(torch.nn.Module):
"""Module for alignment text and mel spectrogram. """
"""
Module for aligning text and mel spectrogram.

Args:
n_mel_channels: Dimension of mel spectrogram.
n_text_channels: Dimension of text embeddings.
n_att_channels: Dimension of the projected key and query space used for the distance computation.
temperature: Temperature to scale distance by.
Suggested to be 0.0005 when using dist_type "l2" and 15.0 when using "cosine".
condition_types: List of types for nemo.collections.tts.modules.submodules.ConditionalInput.
dist_type: Distance type to use for similarity measurement. Supports "l2" and "cosine" distance.
"""

def __init__(
self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=0.0005, condition_types=[]
self,
n_mel_channels=80,
n_text_channels=512,
n_att_channels=80,
temperature=0.0005,
condition_types=[],
dist_type="l2",
):
super().__init__()
self.temperature = temperature
@@ -45,27 +63,60 @@ def __init__(
torch.nn.ReLU(),
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
)
if dist_type == "l2":
self.dist_fn = self.get_euclidean_dist
elif dist_type == "cosine":
self.dist_fn = self.get_cosine_dist
else:
raise ValueError(f"Unknown distance type '{dist_type}'")

@staticmethod
def _apply_mask(inputs, mask, mask_value):
if mask is None:
return

mask = rearrange(mask, "B T2 1 -> B 1 1 T2")
inputs.data.masked_fill_(mask, mask_value)

def get_dist(self, keys, queries, mask=None):
"""Calculation of distance matrix.

Args:
queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
queries (torch.tensor): B x C1 x T1 tensor (probably going to be mel data).
keys (torch.tensor): B x C2 x T2 tensor (text data).
mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries and also can be used
for ignoring unnecessary elements from keys in the resulting distance matrix (True = mask element, False = leave unchanged).
Output:
dist (torch.tensor): B x T1 x T2 tensor.
"""
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
queries_enc = self.query_proj(queries) # B x n_attn_dims x T1
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2
dist = attn.sum(1, keepdim=True) # B x 1 x T1 x T2
# B x C x T1
queries_enc = self.query_proj(queries)
# B x C x T2
keys_enc = self.key_proj(keys)
# B x 1 x T1 x T2
dist = self.dist_fn(queries_enc=queries_enc, keys_enc=keys_enc)

self._apply_mask(dist, mask, float("inf"))

if mask is not None:
dist.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), float("inf"))
return dist

return dist.squeeze(1)
@staticmethod
def get_euclidean_dist(queries_enc, keys_enc):
queries_enc = rearrange(queries_enc, "B C T1 -> B C T1 1")
keys_enc = rearrange(keys_enc, "B C T2 -> B C 1 T2")
# B x C x T1 x T2
distance = (queries_enc - keys_enc) ** 2
# B x 1 x T1 x T2
l2_dist = distance.sum(axis=1, keepdim=True)
return l2_dist

@staticmethod
def get_cosine_dist(queries_enc, keys_enc):
queries_enc = rearrange(queries_enc, "B C T1 -> B C T1 1")
keys_enc = rearrange(keys_enc, "B C T2 -> B C 1 T2")
cosine_dist = -torch.nn.functional.cosine_similarity(queries_enc, keys_enc, dim=1)
cosine_dist = rearrange(cosine_dist, "B T1 T2 -> B 1 T1 T2")
return cosine_dist

@staticmethod
def get_durations(attn_soft, text_len, spect_len):
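Stepping outside the diff for a moment, here is a self-contained sketch of the two distance functions introduced above (only torch and einops assumed). It checks that both return the B x 1 x T1 x T2 shape the comments promise, and that cosine distance is bounded in [-1, 1] while squared L2 distance is not, which is why the suggested temperatures differ by several orders of magnitude:

```python
import torch
from einops import rearrange


def euclidean_dist(queries_enc, keys_enc):
    q = rearrange(queries_enc, "B C T1 -> B C T1 1")
    k = rearrange(keys_enc, "B C T2 -> B C 1 T2")
    return ((q - k) ** 2).sum(dim=1, keepdim=True)  # B x 1 x T1 x T2


def cosine_dist(queries_enc, keys_enc):
    q = rearrange(queries_enc, "B C T1 -> B C T1 1")
    k = rearrange(keys_enc, "B C T2 -> B C 1 T2")
    d = -torch.nn.functional.cosine_similarity(q, k, dim=1)  # B x T1 x T2
    return rearrange(d, "B T1 T2 -> B 1 T1 T2")


q, k = torch.randn(2, 80, 100), torch.randn(2, 80, 20)  # B=2, C=80, T1=100, T2=20
assert euclidean_dist(q, k).shape == cosine_dist(q, k).shape == (2, 1, 100, 20)
assert cosine_dist(q, k).abs().max() <= 1.0 + 1e-6  # cosine distance is bounded
```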
@@ -96,8 +147,7 @@ def get_mean_dist_by_durations(dist, durations, mask=None):
batch_size, t1_size, t2_size = dist.size()
assert torch.all(torch.eq(durations.sum(dim=1), t1_size))

if mask is not None:
dist = dist.masked_fill(mask.permute(0, 2, 1).unsqueeze(2), 0)
AlignmentEncoder._apply_mask(dist, mask, 0)

# TODO(oktai15): make it more efficient
mean_dist_by_durations = []
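The shared `_apply_mask` helper leans on broadcasting: the B x T2 x 1 padding mask is rearranged to B x 1 x 1 x T2 so that a single in-place `masked_fill_` covers every channel and every query frame. A shape sketch, again assuming only torch and einops:

```python
import torch
from einops import rearrange

B, T1, T2 = 2, 100, 20
dist = torch.randn(B, 1, T1, T2)  # shaped like the output of get_dist
mask = torch.zeros(B, T2, 1, dtype=torch.bool)
mask[:, 15:] = True  # pretend the last 5 text positions are padding

# B x T2 x 1 -> B x 1 x 1 x T2 broadcasts across the channel and T1 axes.
dist.data.masked_fill_(rearrange(mask, "B T2 1 -> B 1 1 T2"), float("inf"))
assert torch.isinf(dist[..., 15:]).all() and torch.isfinite(dist[..., :15]).all()
```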
@@ -149,7 +199,7 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
"""Forward pass of the aligner encoder.

Args:
queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
queries (torch.tensor): B x C1 x T1 tensor (probably going to be mel data).
keys (torch.tensor): B x C2 x T2 tensor (text data).
mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries (True = mask element, False = leave unchanged).
attn_prior (torch.tensor): prior for attention matrix.
@@ -159,20 +209,20 @@ def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
"""
keys = self.cond_input(keys.transpose(1, 2), conditioning).transpose(1, 2)
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
queries_enc = self.query_proj(queries) # B x n_attn_dims x T1

# Simplistic Gaussian Isotopic Attention
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2
attn = -self.temperature * attn.sum(1, keepdim=True)
# B x C x T1
queries_enc = self.query_proj(queries)
# B x C x T2
keys_enc = self.key_proj(keys)
# B x 1 x T1 x T2
distance = self.dist_fn(queries_enc=queries_enc, keys_enc=keys_enc)
attn = -self.temperature * distance

if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

attn_logprob = attn.clone()

if mask is not None:
attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))
self._apply_mask(attn, mask, -float("inf"))

attn = self.softmax(attn) # softmax along T2
return attn, attn_logprob
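End to end, `forward` now maps either distance to attention logits via `attn = -temperature * distance`, which explains the paired defaults: cosine distance lives in [-1, 1], so a temperature of 15.0 keeps the logits in roughly [-15, 15], whereas squared L2 distances are unbounded and need the far smaller 0.0005. A usage sketch, assuming a NeMo build that includes this PR:

```python
import torch

from nemo.collections.tts.modules.aligner import AlignmentEncoder

# Hypothetical shapes: batch of 2, 80-band mels of length 100, 20 text tokens.
aligner = AlignmentEncoder(n_text_channels=512, dist_type="cosine", temperature=15.0)
queries = torch.randn(2, 80, 100)  # B x n_mel_channels x T1
keys = torch.randn(2, 512, 20)  # B x n_text_channels x T2
mask = torch.zeros(2, 20, 1, dtype=torch.bool)  # no padding in this toy batch

attn, attn_logprob = aligner(queries, keys, mask=mask)
assert attn.shape == attn_logprob.shape == (2, 1, 100, 20)
# softmax is over T2, so each mel frame's attention row sums to one
assert torch.allclose(attn.sum(dim=3), torch.ones(2, 1, 100), atol=1e-5)
```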
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/submodules.py
@@ -509,7 +509,7 @@ def forward(self, inputs, conditioning=None):
inputs = inputs + conditioning

if "concat" in self.condition_types:
conditioning = conditionting.repeat(1, inputs.shape[1], 1)
conditioning = conditioning.repeat(1, inputs.shape[1], 1)
inputs = torch.cat([inputs, conditioning])
inputs = self.concat_proj(inputs)
