From c9766e073112a4dd7487d7a1236f57f03be412cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Jukic=CC=81?= Date: Wed, 9 Aug 2023 18:44:15 -0700 Subject: [PATCH] Fix discriminator update in AudioCodecModel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić --- examples/tts/audio_codec.py | 3 ++ examples/tts/conf/audio_codec/encodec.yaml | 4 ++- nemo/collections/tts/models/audio_codec.py | 32 +++++++++++++++++----- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/examples/tts/audio_codec.py b/examples/tts/audio_codec.py index ffc91cd98f01..9244721298de 100644 --- a/examples/tts/audio_codec.py +++ b/examples/tts/audio_codec.py @@ -13,14 +13,17 @@ # limitations under the License. import pytorch_lightning as pl +from omegaconf import OmegaConf from nemo.collections.tts.models import AudioCodecModel from nemo.core.config import hydra_runner +from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @hydra_runner(config_path="conf/audio_codec", config_name="audio_codec") def main(cfg): + logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg)) trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) model = AudioCodecModel(cfg=cfg.model, trainer=trainer) diff --git a/examples/tts/conf/audio_codec/encodec.yaml b/examples/tts/conf/audio_codec/encodec.yaml index 3cbf00e1a35f..402be0ad4721 100644 --- a/examples/tts/conf/audio_codec/encodec.yaml +++ b/examples/tts/conf/audio_codec/encodec.yaml @@ -37,7 +37,9 @@ model: samples_per_frame: ${samples_per_frame} time_domain_loss_scale: 0.1 # Probability of updating the discriminator during each training step - disc_update_prob: 0.67 + # For example, update the discriminator 2/3 times (2 updates for every 3 batches) + disc_updates_per_period: 2 + disc_update_period: 3 # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length] mel_loss_resolutions: [ diff --git 
a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index f840eab279ed..4acacfd51397 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -52,7 +52,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.sample_rate = cfg.sample_rate self.samples_per_frame = cfg.samples_per_frame - self.disc_update_prob = cfg.get("disc_update_prob", 1.0) + self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1) + self.disc_update_period = cfg.get("disc_update_period", 1) + if self.disc_updates_per_period > self.disc_update_period: + raise ValueError( + f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less than or equal to the configured period ({self.disc_update_period})' + ) + self.audio_encoder = instantiate(cfg.audio_encoder) # Optionally, add gaussian noise to encoder output as an information bottleneck @@ -204,18 +210,26 @@ def _process_batch(self, batch): return audio, audio_len, audio_gen, commit_loss + @property + def disc_update_prob(self) -> float: + """Probability of updating the discriminator. + """ + return self.disc_updates_per_period / self.disc_update_period + + def should_update_disc(self, batch_idx) -> bool: + """Decide whether to update the discriminator based + on the batch index and configured discriminator update period. 
+ """ + disc_update_step = batch_idx % self.disc_update_period + return disc_update_step < self.disc_updates_per_period + def training_step(self, batch, batch_idx): optim_gen, optim_disc = self.optimizers() - optim_gen.zero_grad() audio, audio_len, audio_gen, commit_loss = self._process_batch(batch) - if self.disc_update_prob < random.random(): - loss_disc = None - else: + if self.should_update_disc(batch_idx): # Train discriminator - optim_disc.zero_grad() - disc_scores_real, disc_scores_gen, _, _ = self.discriminator( audio_real=audio, audio_gen=audio_gen.detach() ) @@ -224,6 +238,9 @@ def training_step(self, batch, batch_idx): self.manual_backward(train_disc_loss) optim_disc.step() + optim_disc.zero_grad() + else: + loss_disc = None loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len) train_loss_time_domain = self.time_domain_loss_scale * loss_time_domain @@ -245,6 +262,7 @@ def training_step(self, batch, batch_idx): self.manual_backward(loss_gen_all) optim_gen.step() + optim_gen.zero_grad() self.update_lr()