From 44456b0483bf42b1337a8e408ac17af38b26b1fa Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 29 Apr 2022 07:28:39 -0300 Subject: [PATCH] Fix style --- TTS/vocoder/models/gan.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index f5d0a33e77..367efdc24b 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -89,14 +89,14 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[ if optimizer_idx not in [0, 1]: raise ValueError(" [!] Unexpected `optimizer_idx`.") - if optimizer_idx == 0: # DISCRIMINATOR optimization # generator pass y_hat = self.model_g(x)[:, :, : y.size(2)] - + # cache for generator loss + # pylint: disable=W0201 self.y_hat_g = y_hat self.y_hat_sub = None self.y_sub_g = None @@ -178,7 +178,9 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[ feats_fake, feats_real = None, None # compute losses - loss_dict = criterion[optimizer_idx](self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g) + loss_dict = criterion[optimizer_idx]( + self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g + ) outputs = {"model_outputs": self.y_hat_g} return outputs, loss_dict