[Fix] Discriminator update in AudioCodecModel #7209

Merged 1 commit on Aug 10, 2023
3 changes: 3 additions & 0 deletions examples/tts/audio_codec.py
@@ -13,14 +13,17 @@
# limitations under the License.

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.tts.models import AudioCodecModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/audio_codec", config_name="audio_codec")
def main(cfg):
    logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg))
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = AudioCodecModel(cfg=cfg.model, trainer=trainer)
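The only functional addition in this file is a dump of the resolved config at startup. A minimal standalone sketch of the same pattern (the config contents below are placeholders for illustration, not NeMo defaults):

```python
# Minimal sketch (not the NeMo entry point): log a resolved OmegaConf config as
# YAML before training, mirroring the line added to examples/tts/audio_codec.py.
import logging

from omegaconf import OmegaConf

logging.basicConfig(level=logging.INFO)

# Placeholder config values for illustration only.
cfg = OmegaConf.create({"trainer": {"max_epochs": 10}, "model": {"sample_rate": 24000}})
logging.info('\nConfig Params:\n%s', OmegaConf.to_yaml(cfg))
```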
4 changes: 3 additions & 1 deletion examples/tts/conf/audio_codec/encodec.yaml
@@ -37,7 +37,9 @@ model:
  samples_per_frame: ${samples_per_frame}
  time_domain_loss_scale: 0.1
  # Probability of updating the discriminator during each training step
  disc_update_prob: 0.67
  # For example, update the discriminator 2/3 of the time (2 updates for every 3 batches)
  disc_updates_per_period: 2
  disc_update_period: 3

  # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
  mel_loss_resolutions: [
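The YAML change swaps the probabilistic disc_update_prob for a deterministic period: the discriminator is updated on disc_updates_per_period out of every disc_update_period batches, so the effective update rate here stays at 2/3, matching the old 0.67 on average. A small illustrative sketch of the resulting schedule (not part of the PR):

```python
# With disc_updates_per_period=2 and disc_update_period=3 (values taken from the
# YAML above), the discriminator is updated on the first 2 batches of every
# 3-batch period, i.e. on 2/3 of training steps.
disc_updates_per_period = 2
disc_update_period = 3

for batch_idx in range(6):
    update_disc = (batch_idx % disc_update_period) < disc_updates_per_period
    print(batch_idx, update_disc)
# Expected output: 0 True, 1 True, 2 False, 3 True, 4 True, 5 False
```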
32 changes: 25 additions & 7 deletions nemo/collections/tts/models/audio_codec.py
@@ -52,7 +52,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.sample_rate = cfg.sample_rate
        self.samples_per_frame = cfg.samples_per_frame

        self.disc_update_prob = cfg.get("disc_update_prob", 1.0)
        self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
        self.disc_update_period = cfg.get("disc_update_period", 1)
        if self.disc_updates_per_period > self.disc_update_period:
            raise ValueError(
                f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less than or equal to the configured period ({self.disc_update_period})'
            )

        self.audio_encoder = instantiate(cfg.audio_encoder)

        # Optionally, add gaussian noise to encoder output as an information bottleneck
@@ -204,18 +210,26 @@ def _process_batch(self, batch):

        return audio, audio_len, audio_gen, commit_loss

    @property
    def disc_update_prob(self) -> float:
        """Probability of updating the discriminator.
        """
        return self.disc_updates_per_period / self.disc_update_period

    def should_update_disc(self, batch_idx) -> bool:
"""Decide whether to update the descriminator based
on the batch index and configured discriminator update period.
"""
        disc_update_step = batch_idx % self.disc_update_period
        return disc_update_step < self.disc_updates_per_period

    def training_step(self, batch, batch_idx):
        optim_gen, optim_disc = self.optimizers()
        optim_gen.zero_grad()

        audio, audio_len, audio_gen, commit_loss = self._process_batch(batch)

        if self.disc_update_prob < random.random():
            loss_disc = None
        else:
        if self.should_update_disc(batch_idx):
            # Train discriminator
            optim_disc.zero_grad()

            disc_scores_real, disc_scores_gen, _, _ = self.discriminator(
                audio_real=audio, audio_gen=audio_gen.detach()
            )
@@ -224,6 +238,9 @@ def training_step(self, batch, batch_idx):

            self.manual_backward(train_disc_loss)
            optim_disc.step()
            optim_disc.zero_grad()
        else:
            loss_disc = None

        loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
        train_loss_time_domain = self.time_domain_loss_scale * loss_time_domain
@@ -245,6 +262,7 @@ def training_step(self, batch, batch_idx):

        self.manual_backward(loss_gen_all)
        optim_gen.step()
        optim_gen.zero_grad()

        self.update_lr()

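Putting the pieces together, the training step now gates the discriminator optimizer on a deterministic batch-index schedule instead of a random.random() check, and zeroes each optimizer's gradients right after stepping. The following toy sketch mirrors that control flow outside of NeMo/Lightning; the modules, losses, and hyperparameters are placeholders for illustration and are not the NeMo implementation.

```python
# Toy stand-in (not the NeMo AudioCodecModel) showing a deterministic
# discriminator-update schedule inside a manual-optimization GAN-style step.
import torch


class ToyCodecGAN(torch.nn.Module):
    def __init__(self, disc_updates_per_period: int = 2, disc_update_period: int = 3):
        super().__init__()
        if disc_updates_per_period > disc_update_period:
            raise ValueError("disc_updates_per_period must be <= disc_update_period")
        self.disc_updates_per_period = disc_updates_per_period
        self.disc_update_period = disc_update_period
        self.generator = torch.nn.Linear(8, 8)
        self.discriminator = torch.nn.Linear(8, 1)
        self.optim_gen = torch.optim.Adam(self.generator.parameters(), lr=1e-3)
        self.optim_disc = torch.optim.Adam(self.discriminator.parameters(), lr=1e-3)

    def should_update_disc(self, batch_idx: int) -> bool:
        # Deterministic schedule in place of the old random.random() gate.
        return (batch_idx % self.disc_update_period) < self.disc_updates_per_period

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
        audio_gen = self.generator(batch)

        if self.should_update_disc(batch_idx):
            # Discriminator step on detached generator output.
            # Placeholder "loss"; a real GAN loss would also use real audio.
            self.optim_disc.zero_grad()
            loss_disc = self.discriminator(audio_gen.detach()).mean()
            loss_disc.backward()
            self.optim_disc.step()

        # Generator step runs on every batch.
        self.optim_gen.zero_grad()
        loss_gen = -self.discriminator(audio_gen).mean()
        loss_gen.backward()
        self.optim_gen.step()


model = ToyCodecGAN()
for batch_idx in range(6):
    model.training_step(torch.randn(4, 8), batch_idx)
```

Compared with the old probabilistic gate, the deterministic schedule makes runs reproducible and keeps the generator/discriminator update ratio exact over every period rather than only in expectation.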