Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HiFiGAN compatibility #334

Merged
merged 3 commits into from
Feb 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 50 additions & 51 deletions parallel_wavegan/layers/residual_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ def __init__(
use_additional_convs=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
pad="ReplicationPad1d",
pad_params={},
use_causal_conv=False,
):
"""Initialize HiFiGANResidualBlock module.
Expand All @@ -166,8 +164,6 @@ def __init__(
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (dict): Hyperparameters for activation function.
pad (str): Padding function module name before convolution layer.
pad_params (dict): Hyperparameters for padding function.
use_causal_conv (bool): Whether to use causal structure.

"""
Expand All @@ -180,66 +176,69 @@ def __init__(
assert kernel_size % 2 == 1, "Kernel size must be odd number."
for dilation in dilations:
if not use_causal_conv:
conv = torch.nn.Sequential(
getattr(torch.nn, pad)(
(kernel_size - 1) // 2 * dilation, **pad_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
dilation=dilation,
bias=bias,
),
)
else:
conv = CausalConv1d(
channels,
channels,
kernel_size,
dilation=dilation,
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv,
)
]
if use_additional_convs:
if not use_causal_conv:
conv = torch.nn.Sequential(
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
dilation=1,
1,
dilation=dilation,
bias=bias,
padding=(kernel_size - 1) // 2 * dilation,
),
)
else:
conv = CausalConv1d(
channels,
channels,
kernel_size,
dilation=1,
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.convs2 += [
]
else:
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv,
CausalConv1d(
channels,
channels,
kernel_size,
dilation=dilation,
bias=bias,
),
)
]
if use_additional_convs:
if not use_causal_conv:
self.convs2 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
dilation=1,
bias=bias,
padding=(kernel_size - 1) // 2,
),
)
]
else:
self.convs2 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
CausalConv1d(
channels,
channels,
kernel_size,
dilation=1,
bias=bias,
),
),
]

def forward(self, x):
"""Calculate forward propagation.
Expand Down
111 changes: 53 additions & 58 deletions parallel_wavegan/models/hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def __init__(
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
pad="ReplicationPad1d",
pad_params={},
use_causal_conv=False,
use_weight_norm=True,
):
Expand All @@ -56,8 +54,6 @@ def __init__(
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (dict): Hyperparameters for activation function.
pad (str): Padding function module name before convolution layer.
pad_params (dict): Hyperparameters for padding function.
use_causal_conv (bool): Whether to use causal structure.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
Expand All @@ -75,56 +71,56 @@ def __init__(
self.num_blocks = len(resblock_kernel_sizes)
self.use_causal_conv = use_causal_conv
if not use_causal_conv:
self.input_conv = torch.nn.Sequential(
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
bias=bias,
),
self.input_conv = torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
bias=bias,
padding=(kernel_size - 1) // 2,
)
else:
self.input_conv = CausalConv1d(
in_channels,
channels,
kernel_size,
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.upsamples = torch.nn.ModuleList()
self.blocks = torch.nn.ModuleList()
for i in range(len(upsample_kernel_sizes)):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
if not use_causal_conv:
conv = torch.nn.ConvTranspose1d(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
bias=bias,
)
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.ConvTranspose1d(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
bias=bias,
),
)
]
else:
conv = CausalConvTranspose1d(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv,
)
]
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
CausalConvTranspose1d(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
bias=bias,
),
)
]
for j in range(len(resblock_kernel_sizes)):
self.blocks += [
ResidualBlock(
Expand All @@ -135,37 +131,36 @@ def __init__(
use_additional_convs=use_additional_convs,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
pad=pad,
pad_params=pad_params,
use_causal_conv=use_causal_conv,
)
]
if not use_causal_conv:
conv = torch.nn.Sequential(
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
self.output_conv = torch.nn.Sequential(
# NOTE(kan-bayashi): follow official implementation but why
# using different slope parameter here? (0.1 vs. 0.01)
torch.nn.LeakyReLU(),
torch.nn.Conv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
bias=bias,
padding=(kernel_size - 1) // 2,
),
torch.nn.Tanh(),
)
else:
conv = CausalConv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
bias=bias,
pad=pad,
pad_params=pad_params,
self.output_conv = torch.nn.Sequential(
# NOTE(kan-bayashi): follow official implementation but why
# using different slope parameter here? (0.1 vs. 0.01)
torch.nn.LeakyReLU(),
CausalConv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
bias=bias,
),
torch.nn.Tanh(),
)
self.output_conv = torch.nn.Sequential(
# NOTE(kan-bayashi): follow official implementation but why
# using different slope parameter here? (0.1 vs. 0.01)
torch.nn.LeakyReLU(),
conv,
torch.nn.Tanh(),
)

# apply weight norm
if use_weight_norm:
Expand Down
2 changes: 0 additions & 2 deletions test/test_hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def make_hifigan_generator_args(**kwargs):
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
pad="ReplicationPad1d",
pad_params={},
use_weight_norm=True,
use_causal_conv=False,
)
Expand Down