From 9cc17be53a59d9e50e57659990ff12157328e622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 15 Apr 2021 16:36:51 +0200 Subject: [PATCH] formatting and a small bug fix in Tacotron model --- TTS/bin/train_vocoder_gan.py | 4 +++- TTS/tts/layers/tacotron/common_layers.py | 10 +++++++++- TTS/tts/layers/tacotron/tacotron.py | 1 - TTS/tts/models/tacotron.py | 1 + TTS/tts/utils/generic_utils.py | 4 ++-- TTS/tts/utils/text/symbols.py | 2 +- TTS/vocoder/tf/models/melgan_generator.py | 1 + 7 files changed, 17 insertions(+), 6 deletions(-) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 38aa69e578..730506c192 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -445,7 +445,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) # Sample audio predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy() real_waveform = y_G[0].squeeze(0).cpu().numpy() - tb_logger.tb_eval_audios(global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"]) + tb_logger.tb_eval_audios( + global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"] + ) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) diff --git a/TTS/tts/layers/tacotron/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py index d110319a9f..f78ff1e75f 100644 --- a/TTS/tts/layers/tacotron/common_layers.py +++ b/TTS/tts/layers/tacotron/common_layers.py @@ -87,7 +87,15 @@ class Prenet(nn.Module): """ # pylint: disable=dangerous-default-value - def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True): + def __init__( + self, + in_features, + prenet_type="original", + prenet_dropout=True, + dropout_at_inference=False, + out_features=[256, 256], + bias=True, + ): super().__init__() self.prenet_type = prenet_type self.prenet_dropout = prenet_dropout diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py index 153af5b766..dcb5fdc505 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -306,7 +306,6 @@ def __init__( # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State # attention_rnn generates queries for the attention mechanism self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim) - self.attention = init_attn( attn_type=attn_type, query_dim=self.query_dim, diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 1608a92a54..297c8e3e93 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -93,6 +93,7 @@ def __init__( attn_norm, prenet_type, prenet_dropout, + prenet_dropout_at_inference, forward_attn, trans_agent, forward_attn_mask, diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index dcbdc5b765..641d49b2d4 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -68,7 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): attn_norm=c.attention_norm, prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout, - prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False, + prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False, forward_attn=c.use_forward_attn, trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask, @@ -97,7 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): attn_norm=c.attention_norm, prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout, - prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False, + prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False, forward_attn=c.use_forward_attn, trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask, diff --git a/TTS/tts/utils/text/symbols.py b/TTS/tts/utils/text/symbols.py index b48082b3cf..6efe392009 100644 --- a/TTS/tts/utils/text/symbols.py +++ b/TTS/tts/utils/text/symbols.py @@ -11,7 +11,7 @@ def make_symbols( characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^" ): # pylint: disable=redefined-outer-name """ Function to create symbols and phonemes """ - _symbols = list(characters) + _symbols = list(characters) _symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols _symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols _symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols diff --git a/TTS/vocoder/tf/models/melgan_generator.py b/TTS/vocoder/tf/models/melgan_generator.py index 90e0fa0ca0..205a240ec2 100644 --- a/TTS/vocoder/tf/models/melgan_generator.py +++ b/TTS/vocoder/tf/models/melgan_generator.py @@ -10,6 +10,7 @@ from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack + # pylint: disable=too-many-ancestors # pylint: disable=abstract-method class MelganGenerator(tf.keras.models.Model):