fix norm issue #409

Merged: 8 commits, Jun 19, 2023
12 changes: 8 additions & 4 deletions parallel_wavegan/bin/train.py
@@ -156,15 +156,19 @@ def load_checkpoint(self, checkpoint_path, load_only_params=False):
         state_dict = torch.load(checkpoint_path, map_location="cpu")
         if self.config["distributed"]:
             self.model["generator"].module.load_state_dict(
-                state_dict["model"]["generator"]
+                state_dict["model"]["generator"],
             )
             self.model["discriminator"].module.load_state_dict(
-                state_dict["model"]["discriminator"]
+                state_dict["model"]["discriminator"],
+                strict=False,
             )
         else:
-            self.model["generator"].load_state_dict(state_dict["model"]["generator"])
+            self.model["generator"].load_state_dict(
+                state_dict["model"]["generator"],
+            )
             self.model["discriminator"].load_state_dict(
-                state_dict["model"]["discriminator"]
+                state_dict["model"]["discriminator"],
+                strict=False,
             )
         if not load_only_params:
             self.steps = state_dict["steps"]
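For context on the strict=False change: load_state_dict with strict=False reports, rather than raises on, keys that are missing from or unexpected in the incoming state dict, so the discriminator checkpoint can still be loaded when its norm-related parameters do not line up with the current model. A minimal sketch of that behavior with a throwaway module (illustrative only, not code from this PR):

import torch

# Toy module used only to illustrate strict vs. non-strict loading.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

state_dict = model.state_dict()
state_dict.pop("0.bias")                        # simulate a missing key
state_dict["extra.weight"] = torch.zeros(4, 4)  # simulate an unexpected key

# strict=True (the default) would raise a RuntimeError here; strict=False
# loads what matches and returns the mismatches instead of failing.
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['0.bias']
print(result.unexpected_keys)  # ['extra.weight']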
89 changes: 87 additions & 2 deletions parallel_wavegan/models/hifigan.py
@@ -571,13 +571,18 @@ def __init__(
             raise ValueError("Either use use_weight_norm or use_spectral_norm.")
 
         # apply weight norm
+        self.use_weight_norm = use_weight_norm
         if use_weight_norm:
             self.apply_weight_norm()
 
         # apply spectral norm
+        self.use_spectral_norm = use_spectral_norm
         if use_spectral_norm:
             self.apply_spectral_norm()
 
+        # backward compatibility
+        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
     def forward(self, x):
         """Calculate forward propagation.
 
@@ -599,7 +604,7 @@ def apply_weight_norm(self):
         """Apply weight normalization module from all of the layers."""
 
         def _apply_weight_norm(m):
-            if isinstance(m, torch.nn.Conv2d):
+            if isinstance(m, torch.nn.Conv1d):
                 torch.nn.utils.weight_norm(m)
                 logging.debug(f"Weight norm is applied to {m}.")
 
@@ -609,12 +614,92 @@ def apply_spectral_norm(self):
         """Apply spectral normalization module from all of the layers."""
 
         def _apply_spectral_norm(m):
-            if isinstance(m, torch.nn.Conv2d):
+            if isinstance(m, torch.nn.Conv1d):
                 torch.nn.utils.spectral_norm(m)
                 logging.debug(f"Spectral norm is applied to {m}.")
 
         self.apply(_apply_spectral_norm)
 
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def remove_spectral_norm(self):
+        """Remove spectral normalization module from all of the layers."""
+
+        def _remove_spectral_norm(m):
+            try:
+                logging.debug(f"Spectral norm is removed from {m}.")
+                torch.nn.utils.remove_spectral_norm(m)
+            except ValueError:  # this module didn't have spectral norm
+                return
+
+        self.apply(_remove_spectral_norm)
+
+    def _load_state_dict_pre_hook(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        """Fix the weight / spectral normalization compatibility issue.
+
+        Some pretrained models were trained with configs that enable weight /
+        spectral normalization, but the norm was never actually applied. This
+        causes a parameter mismatch with the configs, so when the mismatch is
+        detected while loading a pretrained model, we remove the norm here.
+
+        See also:
+            - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
+            - https://github.com/espnet/espnet/pull/5240
+
+        """
+        if self.use_weight_norm and not any(
+            "weight_g" in k for k in state_dict.keys()
+        ):
+            logging.warning(
+                "It seems weight norm is not applied in the pretrained model but the"
+                " current model uses it. To keep the compatibility, we remove the norm"
+                " from the current model. This may cause a training error due to the"
+                " parameter mismatch when finetuning. To avoid this issue, please"
+                " change the following parameters in the config to false: \n"
+                " - discriminator_params.follow_official_norm \n"
+                " - discriminator_params.scale_discriminator_params.use_weight_norm \n"
+                " - discriminator_params.scale_discriminator_params.use_spectral_norm \n"
+                " See also: https://github.com/kan-bayashi/ParallelWaveGAN/pull/409"
+            )
+            self.remove_weight_norm()
+            self.use_weight_norm = False
+
+        if self.use_spectral_norm and not any(
+            "weight_u" in k for k in state_dict.keys()
+        ):
+            logging.warning(
+                "It seems spectral norm is not applied in the pretrained model but the"
+                " current model uses it. To keep the compatibility, we remove the norm"
+                " from the current model. This may cause a training error due to the"
+                " parameter mismatch when finetuning. To avoid this issue, please"
+                " change the following parameters in the config to false: \n"
+                " - discriminator_params.follow_official_norm \n"
+                " - discriminator_params.scale_discriminator_params.use_weight_norm \n"
+                " - discriminator_params.scale_discriminator_params.use_spectral_norm \n"
+                " See also: https://github.com/kan-bayashi/ParallelWaveGAN/pull/409"
+            )
+            self.remove_spectral_norm()
+            self.use_spectral_norm = False
+
+
 class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
     """HiFi-GAN multi-scale discriminator module."""
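The checkpoint inspection in _load_state_dict_pre_hook works because PyTorch's norm wrappers rename parameters: torch.nn.utils.weight_norm replaces a layer's weight with a weight_g / weight_v pair, and torch.nn.utils.spectral_norm registers a weight_u buffer (alongside weight_orig and weight_v). If none of those keys appear in the incoming state dict, the norm cannot have been applied when that checkpoint was trained. A small standalone sketch of the naming behavior (illustrative only, not code from this PR):

import torch

conv = torch.nn.Conv1d(1, 1, 3)
print(sorted(conv.state_dict()))  # ['bias', 'weight']

torch.nn.utils.weight_norm(conv)
print(sorted(conv.state_dict()))  # ['bias', 'weight_g', 'weight_v']

torch.nn.utils.remove_weight_norm(conv)
print(sorted(conv.state_dict()))  # ['bias', 'weight']

conv = torch.nn.Conv1d(1, 1, 3)
torch.nn.utils.spectral_norm(conv)
print(sorted(conv.state_dict()))  # ['bias', 'weight_orig', 'weight_u', 'weight_v']

This is also why the original isinstance check against torch.nn.Conv2d silently skipped the scale discriminator's Conv1d layers: the configs requested the norm, but it was never attached. The fixed check plus the pre-hook keep both old and new checkpoints loadable.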
23 changes: 23 additions & 0 deletions test/test_hifigan.py
@@ -6,12 +6,15 @@
 """Test code for HiFi-GAN modules."""
 
 import logging
+import os
 
 import numpy as np
 import pytest
 import torch
+import yaml
 from test_parallel_wavegan import make_mutli_reso_stft_loss_args
 
+import parallel_wavegan.models
 from parallel_wavegan.losses import (
     DiscriminatorAdversarialLoss,
     FeatureMatchLoss,
@@ -219,3 +222,23 @@ def test_causal_hifigan(dict_g):
         y[..., : c.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(),
         y_[..., : c_.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(),
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Run only locally")
+def test_fix_norm_issue():
+    from parallel_wavegan.utils import download_pretrained_model
+
+    checkpoint = download_pretrained_model("ljspeech_hifigan.v1")
+    config = os.path.join(os.path.dirname(checkpoint), "config.yml")
+    with open(config) as f:
+        config = yaml.load(f, Loader=yaml.Loader)
+
+    # get model and load parameters
+    discriminator_type = config.get("discriminator_type")
+    model_class = getattr(
+        parallel_wavegan.models,
+        discriminator_type,
+    )
+    model = model_class(**config["discriminator_params"])
+    state_dict = torch.load(checkpoint, map_location="cpu")["model"]["discriminator"]
+    model.load_state_dict(state_dict, strict=False)
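The test above is gated on CUDA availability because it downloads the pretrained ljspeech_hifigan.v1 checkpoint. The same compatibility path can be exercised offline by loading a norm-free state dict into a norm-enabled model. The sketch below assumes the scale discriminator is exposed as HiFiGANScaleDiscriminator with use_weight_norm / use_spectral_norm constructor flags, matching the __init__ shown in the diff above; the class name is an assumption, not taken from this PR:

import torch

# Assumed class name; adjust to the actual scale discriminator class.
from parallel_wavegan.models.hifigan import HiFiGANScaleDiscriminator

# A model where the norm was never applied, i.e. the state of the affected
# pretrained checkpoints.
plain = HiFiGANScaleDiscriminator(use_weight_norm=False, use_spectral_norm=False)
state_dict = plain.state_dict()

# The current model built with weight norm enabled, as the shipped configs request.
current = HiFiGANScaleDiscriminator(use_weight_norm=True, use_spectral_norm=False)

# The pre-hook sees that no "weight_g" key is present, logs a warning, strips
# the weight norm from the current model, and the load then succeeds.
current.load_state_dict(state_dict, strict=False)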