Adding speaker embedding conditioning in fastpitch (#4986)
Signed-off-by: subhankar-ghosh <subhankar2321@gmail.com>
Signed-off-by: George Zelenfroynd <zelenfr@ya.ru>
subhankar-ghosh authored and Jorjeous committed Sep 27, 2022
1 parent 910b09b commit bfb4edd
Showing 3 changed files with 43 additions and 9 deletions.
7 changes: 7 additions & 0 deletions nemo/collections/tts/models/fastpitch.py
@@ -139,6 +139,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)
+       speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
+       speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
+       speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)

        self.fastpitch = FastPitchModule(
            input_fft,
@@ -150,6 +153,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            cfg.n_mel_channels,
+           cfg.max_token_duration,
+           speaker_emb_condition_prosody,
+           speaker_emb_condition_decoder,
+           speaker_emb_condition_aligner,
        )
        self._input_types = self._output_types = None
        self.export_config = {"enable_volume": False, "enable_ragged_batches": False}
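The three new flags are read with cfg.get(..., False), so configs that do not set them keep the previous single-path behavior. A minimal opt-in sketch (the config file name and "model." key nesting are illustrative assumptions, not part of this diff):

# Hypothetical sketch: enable the new conditioning flags on a FastPitch
# model config before instantiation. File name and key nesting are assumed.
from omegaconf import OmegaConf

cfg = OmegaConf.load("fastpitch_align.yaml")    # assumed multi-speaker config
cfg.model.speaker_emb_condition_prosody = True  # add spk_emb to duration/pitch inputs
cfg.model.speaker_emb_condition_decoder = True  # pass spk_emb into the output FFT
cfg.model.speaker_emb_condition_aligner = True  # pass spk_emb into the aligner keys
# Any flag left unset stays False via cfg.get(...), preserving old behavior.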
5 changes: 4 additions & 1 deletion nemo/collections/tts/modules/aligner.py
@@ -143,18 +143,21 @@ def get_mean_distance_for_word(l2_dists, durs, start_token, num_tokens):

        return dist_sum / total_frames

-   def forward(self, queries, keys, mask=None, attn_prior=None):
+   def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
        """Forward pass of the aligner encoder.
        Args:
            queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
            keys (torch.tensor): B x C2 x T2 tensor (text data).
            mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries (True = mask element, False = leave unchanged).
            attn_prior (torch.tensor): prior for attention matrix.
+           conditioning (torch.tensor): B x 1 x C2 conditioning embedding.
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
            attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
        """
+       if conditioning is not None:
+           keys = keys + conditioning.transpose(1, 2)
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        queries_enc = self.query_proj(queries)  # B x n_attn_dims x T1

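Because the conditioning tensor carries one embedding per utterance (B x 1 x C2), conditioning.transpose(1, 2) yields B x C2 x 1, which broadcasts the same speaker vector across every text position of keys (B x C2 x T2). A quick shape check with illustrative dimensions:

# Shape sketch for keys = keys + conditioning.transpose(1, 2) above.
# Dimensions are illustrative, not taken from the commit.
import torch

B, C2, T2 = 4, 384, 120                     # batch, text-channel dim, text length
keys = torch.randn(B, C2, T2)               # text embeddings
conditioning = torch.randn(B, 1, C2)        # one speaker embedding per utterance
keys = keys + conditioning.transpose(1, 2)  # B x C2 x 1 broadcasts over T2
assert keys.shape == (B, C2, T2)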
40 changes: 32 additions & 8 deletions nemo/collections/tts/modules/fastpitch.py
@@ -141,6 +141,9 @@ def __init__(
        pitch_embedding_kernel_size: int,
        n_mel_channels: int = 80,
        max_token_duration: int = 75,
+       speaker_emb_condition_prosody: bool = False,
+       speaker_emb_condition_decoder: bool = False,
+       speaker_emb_condition_aligner: bool = False,
    ):
        super().__init__()

@@ -152,6 +155,9 @@ def __init__(
        self.learn_alignment = aligner is not None
        self.use_duration_predictor = True
        self.binarize = False
+       self.speaker_emb_condition_prosody = speaker_emb_condition_prosody
+       self.speaker_emb_condition_decoder = speaker_emb_condition_decoder
+       self.speaker_emb_condition_aligner = speaker_emb_condition_aligner

        if n_speakers > 1:
            self.speaker_emb = torch.nn.Embedding(n_speakers, symbols_embedding_dim)
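In forward and infer below, spk_emb is either the integer 0 (no speaker table, or no speaker id supplied) or a looked-up embedding, which is why the aligner branch guards with isinstance(spk_emb, int). A sketch of that lookup, assuming the embedding is unsqueezed to broadcast over time:

# Sketch of the speaker-embedding lookup assumed by the guards below:
# spk_emb is the int 0 when unavailable, else a B x 1 x C tensor.
import torch

n_speakers, symbols_embedding_dim = 10, 384  # illustrative sizes
speaker_emb = torch.nn.Embedding(n_speakers, symbols_embedding_dim)

speaker = torch.tensor([3, 7])               # per-utterance speaker ids
spk_emb = speaker_emb(speaker).unsqueeze(1)  # B x 1 x C, broadcasts over T
assert spk_emb.shape == (2, 1, symbols_embedding_dim)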
@@ -230,19 +236,27 @@ def forward(

        # Input FFT
        enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)

-       log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
+       if self.speaker_emb_condition_prosody:
+           prosody_input = enc_out + spk_emb
+       else:
+           prosody_input = enc_out
+       log_durs_predicted = self.duration_predictor(prosody_input, enc_mask)
        durs_predicted = torch.clamp(torch.exp(log_durs_predicted) - 1, 0, self.max_token_duration)

        attn_soft, attn_hard, attn_hard_dur, attn_logprob = None, None, None, None
        if self.learn_alignment and spec is not None:
            text_emb = self.encoder.word_emb(text)
-           attn_soft, attn_logprob = self.aligner(spec, text_emb.permute(0, 2, 1), enc_mask == 0, attn_prior)
+           if self.speaker_emb_condition_aligner and not isinstance(spk_emb, int):
+               attn_soft, attn_logprob = self.aligner(
+                   spec, text_emb.permute(0, 2, 1), enc_mask == 0, attn_prior, conditioning=spk_emb
+               )
+           else:
+               attn_soft, attn_logprob = self.aligner(spec, text_emb.permute(0, 2, 1), enc_mask == 0, attn_prior)
            attn_hard = binarize_attention_parallel(attn_soft, input_lens, mel_lens)
            attn_hard_dur = attn_hard.sum(2)[:, 0, :]

        # Predict pitch
-       pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
+       pitch_predicted = self.pitch_predictor(prosody_input, enc_mask)
        if pitch is not None:
            if self.learn_alignment and pitch.shape[-1] != pitch_predicted.shape[-1]:
                # Pitch during training is per spectrogram frame, but during inference, it should be per character
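The duration predictor works in log space: torch.exp(log_durs_predicted) - 1 recovers per-token frame counts, and the clamp keeps them in [0, max_token_duration]. A worked example with illustrative values:

# Sketch of the duration decoding above: exp(log_durs) - 1, then clamp.
import torch

log_durs = torch.tensor([[-0.1, 0.7, 2.3]])  # illustrative predictor output
max_token_duration = 75                      # default from __init__ above
durs = torch.clamp(torch.exp(log_durs) - 1, 0, max_token_duration)
print(durs)  # ~[[0.0000, 1.0138, 8.9742]] frames per token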
@@ -262,7 +276,10 @@ def forward(
            len_regulated, dec_lens = regulate_len(durs_predicted, enc_out, pace)

        # Output FFT
-       dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
+       if self.speaker_emb_condition_decoder:
+           dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens, conditioning=spk_emb)
+       else:
+           dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
        spect = self.proj(dec_out).transpose(1, 2)
        return (
            spect,
@@ -286,13 +303,17 @@ def infer(self, *, text, pitch=None, speaker=None, pace=1.0, volume=None):

        # Input FFT
        enc_out, enc_mask = self.encoder(input=text, conditioning=spk_emb)
+       if self.speaker_emb_condition_prosody:
+           prosody_input = enc_out + spk_emb
+       else:
+           prosody_input = enc_out

        # Predict duration and pitch
-       log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
+       log_durs_predicted = self.duration_predictor(prosody_input, enc_mask)
        durs_predicted = torch.clamp(
            torch.exp(log_durs_predicted) - 1.0, self.min_token_duration, self.max_token_duration
        )
-       pitch_predicted = self.pitch_predictor(enc_out, enc_mask) + pitch
+       pitch_predicted = self.pitch_predictor(prosody_input, enc_mask) + pitch
        pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)
@@ -304,7 +325,10 @@ def infer(self, *, text, pitch=None, speaker=None, pace=1.0, volume=None):
            volume_extended = volume_extended.squeeze(-1).float()

        # Output FFT
-       dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
+       if self.speaker_emb_condition_decoder:
+           dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens, conditioning=spk_emb)
+       else:
+           dec_out, _ = self.decoder(input=len_regulated, seq_lens=dec_lens)
        spect = self.proj(dec_out).transpose(1, 2)
        return (
            spect.to(torch.float),
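With a checkpoint trained under these flags, inference is unchanged at the API surface; the new branches fire inside infer. A hypothetical usage sketch (the checkpoint name and speaker id are assumptions):

# Hypothetical inference sketch; not part of the commit.
import torch
from nemo.collections.tts.models import FastPitchModel

model = FastPitchModel.restore_from("multispeaker_fastpitch.nemo")  # assumed checkpoint
model.eval()

tokens = model.parse("Speaker conditioning test.")
with torch.no_grad():
    spect = model.generate_spectrogram(tokens=tokens, speaker=3)  # condition on speaker 3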
