From 8adcd1de8e031ddc55fcda39da93cb909b136aea Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Tue, 17 May 2022 13:37:05 +0200 Subject: [PATCH] Rename `g` as `spk_emb` --- TTS/tts/models/forward_tts.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 147093c5d0..01e3dff57f 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -170,6 +170,11 @@ class ForwardTTS(BaseTTS): If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each input character as in the FastPitch model. + :: + + |-----> (optional) PitchPredictor(o_en, spk_emb) --> pitch_emb --> o_en = o_en + pitch_emb-----| -> CondConv(spk_emb) -> spk_proj + spk, text -> Encoder(text, spk)--> o_en, spk_emb -----> DurationPredictor(o_en, spk_emb)--> dur -------------------------> Expand(o_en, dur) -> PositionEncoding(o_en_expand) -> Decoder(o_en_expand_pos, spk_proj) -> mel_out + `ForwardTTS` can be configured to one of these architectures, - FastPitch @@ -610,19 +615,19 @@ def forward( - g: :math:`[B, C]` - pitch: :math:`[B, 1, T]` """ - g = self._set_speaker_input(aux_input) + spk = self._set_speaker_input(aux_input) # compute sequence masks y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() # [B, 1, T_max2] x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() # [B, 1, T_max] # encoder pass - x_emb, x_mask, g, o_en = self._forward_encoder( - x, x_mask, g + x_emb, x_mask, spk_emb, o_en = self._forward_encoder( + x, x_mask, spk ) # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max] # duration predictor pass if self.args.detach_duration_predictor: - o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g) # [B, 1, T_max] + o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=spk_emb) # [B, 1, T_max] else: - o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g) # [B, 1, T_max] + o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=spk_emb) # [B, 1, T_max] o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) # generate attn mask from predicted durations dur_predictor_attn = self.generate_attn(o_dr.squeeze(1), x_mask) # [B, T_max, T_max2'] @@ -644,7 +649,7 @@ def forward( avg_pitch = None if self.args.use_pitch: o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor( - o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g + o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=spk_emb ) o_en = o_en + o_pitch_emb # expand encoder outputs @@ -652,10 +657,10 @@ def forward( o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask ) # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max] # decoder pass - o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g) # [B, T_max2, C_de] + o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb) # [B, T_max2, C_de] outputs = { "model_outputs": o_de, # [B, T, C] - "g": g, # [B, C] + "spk_emb": spk_emb, # [B, C] "durations_log": o_dr_log.squeeze(1), # [B, T] "durations": o_dr.squeeze(1), # [B, T] "attn_durations": dur_predictor_attn, # for visualization [B, T_en, T_de'] @@ -688,11 +693,11 @@ def inference( - x_lengths: [B] - g: [B, C] """ - g = self._set_speaker_input(aux_input) + spk = self._set_speaker_input(aux_input) x_lengths = torch.tensor(x.shape[1:2]).to(x.device) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() # encoder pass - _, x_mask, g, o_en = self._forward_encoder(x, x_mask, g) + _, x_mask, spk_emb, o_en = self._forward_encoder(x, x_mask, spk) # duration predictor pass o_dr_log = self.duration_predictor(o_en, x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) @@ -700,7 +705,7 @@ def inference( # pitch predictor pass o_pitch = None if self.args.use_pitch: - o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en=o_en, x_mask=x_mask, g=spk_emb) o_en = o_en + o_pitch_emb # expand encoder outputs y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask) @@ -708,13 +713,13 @@ def inference( "alignments": attn, "pitch": o_pitch, "durations": o_dr, - "g": g, + "spk_emb": spk_emb, } if skip_decoder: outputs["o_en_ex"] = o_en_ex else: # decoder pass - outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g) + outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb) return outputs def train_step(self, batch: dict, criterion: nn.Module):