Commit

Merge branch 'main' into prompt-learning-pipeline-parallel
vadam5 authored Jun 9, 2022
2 parents 6aaf222 + 51d7182 commit 68b2fa3
Showing 7 changed files with 192 additions and 65 deletions.
16 changes: 13 additions & 3 deletions nemo/collections/tts/helpers/helpers.py
@@ -367,9 +367,13 @@ def tacotron2_log_to_wandb_func(
swriter.log({"audios": audios})


def plot_alignment_to_numpy(alignment, info=None):
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vmin=None, vmax=None):
if phoneme_seq:
fig, ax = plt.subplots(figsize=(15, 10))
else:
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax)
ax.set_title(title)
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
if info is not None:
@@ -378,6 +382,12 @@ def plot_alignment_to_numpy(alignment, info=None):
plt.ylabel('Encoder timestep')
plt.tight_layout()

if phoneme_seq is not None:
    # For debugging phonemes and durations in alignment maps; not used by default in training code.
    ax.set_yticks(np.arange(len(phoneme_seq)))
    ax.set_yticklabels(phoneme_seq)
    ax.hlines(np.arange(len(phoneme_seq)), xmin=0.0, xmax=max(ax.get_xticks()))

fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
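For context, a minimal sketch of how the extended helper might be called (inputs are illustrative; the return value is assumed to be the rendered figure as a numpy array, per save_figure_to_numpy):

import numpy as np

alignment = np.random.rand(4, 50)  # (encoder steps, decoder steps), illustrative
img = plot_alignment_to_numpy(
    alignment,
    title="soft attention",
    info="step 100",
    phoneme_seq=["HH", "AH", "L", "OW"],  # enables the larger debug figure and y-axis phoneme labels
    vmin=0.0,
    vmax=1.0,
)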
35 changes: 1 addition & 34 deletions nemo/collections/tts/modules/aligner.py
Original file line number Diff line number Diff line change
@@ -16,40 +16,7 @@
from torch import nn

from nemo.collections.tts.helpers.helpers import binarize_attention_parallel


class ConvNorm(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain='linear',
):
super().__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)

self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)

torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
from nemo.collections.tts.modules.submodules import ConvNorm


class AlignmentEncoder(torch.nn.Module):
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/mixer_tts.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
import torch
from torch import nn

from nemo.collections.tts.modules.aligner import ConvNorm
from nemo.collections.tts.modules.submodules import ConvNorm
from nemo.collections.tts.modules.transformer import PositionalEmbedding


91 changes: 85 additions & 6 deletions nemo/collections/tts/modules/submodules.py
@@ -19,6 +19,73 @@
from torch.nn import functional as F


class PartialConv1d(torch.nn.Conv1d):
"""
Zero padding creates a unique identifier for where the edge of the data is, such that the model can
almost always identify exactly where it is relative to either edge given a sufficient receptive field.
Partial padding goes to some lengths to remove this effect.
"""

def __init__(self, *args, **kwargs):

self.multi_channel = False
self.return_mask = False
super(PartialConv1d, self).__init__(*args, **kwargs)

self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]

self.last_size = (None, None, None)
self.update_mask = None
self.mask_ratio = None

def forward(self, input: torch.Tensor, mask_in=None):  # mask_in: optional (batch, 1, time) mask tensor
assert len(input.shape) == 3
# if a mask is input, or tensor shape changed, update mask ratio
if mask_in is not None or self.last_size != tuple(input.shape):
# borisf: disabled update for inference
if self.training:
self.last_size = tuple(input.shape)
with torch.no_grad():
if self.weight_maskUpdater.type() != input.type():
self.weight_maskUpdater = self.weight_maskUpdater.to(input)
if mask_in is None:
mask = torch.ones(1, 1, input.shape[2]).to(input)
else:
mask = mask_in
update_mask = F.conv1d(
mask,
self.weight_maskUpdater,
bias=None,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=1,
)
# use an eps of 1e-6 (rather than 1e-8) so the ratio stays stable in mixed-precision training
mask_ratio = self.slide_winsize / (update_mask + 1e-6)
update_mask = torch.clamp(update_mask, 0, 1)
mask_ratio = torch.mul(mask_ratio, update_mask)
self.update_mask = update_mask
self.mask_ratio = mask_ratio
else:
mask_ratio = self.mask_ratio
update_mask = self.update_mask

raw_out = super(PartialConv1d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
if self.bias is not None:
bias_view = self.bias.view(1, self.out_channels, 1)
output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
output = torch.mul(output, update_mask)
else:
output = torch.mul(raw_out, mask_ratio)

if self.return_mask:
return output, update_mask
else:
return output
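
A minimal sketch of the masked behavior of the class above (shapes and values are illustrative): with a mask that zeroes the tail of the sequence, the partial convolution renormalizes outputs near the masked edge rather than treating the zeros as real data.

import torch

pconv = PartialConv1d(1, 1, kernel_size=3, padding=1, bias=False)
x = torch.randn(1, 1, 8)
mask = torch.ones(1, 1, 8)
mask[..., 6:] = 0.0             # treat the last two frames as padding
y = pconv(x, mask_in=mask)      # outputs near the masked edge are renormalized by mask_ratio
print(y.shape)                  # torch.Size([1, 1, 8])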


class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super().__init__()
@@ -41,13 +108,21 @@ def __init__(
dilation=1,
bias=True,
w_init_gain='linear',
use_partial_padding=False,
use_weight_norm=False,
):
super().__init__()
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)

self.conv = torch.nn.Conv1d(
self.kernel_size = kernel_size
self.dilation = dilation
self.use_partial_padding = use_partial_padding
self.use_weight_norm = use_weight_norm
conv_fn = torch.nn.Conv1d
if self.use_partial_padding:
conv_fn = PartialConv1d
self.conv = conv_fn(
in_channels,
out_channels,
kernel_size=kernel_size,
@@ -56,11 +131,15 @@ def __init__(
dilation=dilation,
bias=bias,
)

torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
if self.use_weight_norm:
self.conv = nn.utils.weight_norm(self.conv)

def forward(self, signal):
conv_signal = self.conv(signal)
def forward(self, signal, mask=None):
if self.use_partial_padding:
conv_signal = self.conv(signal, mask)
else:
conv_signal = self.conv(signal)
return conv_signal
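
And a sketch of the extended ConvNorm (parameter values are illustrative): with use_partial_padding=True the mask is threaded through to PartialConv1d, and use_weight_norm wraps the convolution in weight normalization.

import torch

conv = ConvNorm(80, 256, kernel_size=5, use_partial_padding=True, use_weight_norm=True)
spec = torch.randn(2, 80, 120)  # (batch, mel channels, frames)
mask = torch.ones(2, 1, 120)
mask[1, :, 100:] = 0.0          # second sample is only 100 frames long
out = conv(spec, mask)          # (2, 256, 120)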


90 changes: 73 additions & 17 deletions nemo/collections/tts/torch/data.py
@@ -42,9 +42,11 @@
Energy,
LMTokens,
LogMel,
P_voiced,
Pitch,
SpeakerID,
TTSDataType,
Voiced_mask,
WithLens,
)
from nemo.collections.tts.torch.tts_tokenizers import BaseTokenizer, EnglishCharsTokenizer, EnglishPhonemesTokenizer
@@ -130,6 +132,8 @@ def __init__(
log_mel_folder (Optional[Union[Path, str]]): The folder that contains or will contain log mel spectrograms.
align_prior_matrix_folder (Optional[Union[Path, str]]): The folder that contains or will contain align prior matrices.
pitch_folder (Optional[Union[Path, str]]): The folder that contains or will contain pitch.
voiced_mask_folder (Optional[Union[Path, str]]): The folder that contains or will contain the voiced mask of the pitch.
p_voiced_folder (Optional[Union[Path, str]]): The folder that contains or will contain the voicing probability (p_voiced) of the pitch.
energy_folder (Optional[Union[Path, str]]): The folder that contains or will contain energy.
durs_file (Optional[str]): String path to pickled durations location.
durs_type (Optional[str]): Type of durations. Currently supported only "aligner-based".
@@ -375,6 +379,23 @@ def add_pitch(self, **kwargs):
self.pitch_std = kwargs.pop("pitch_std", None)
self.pitch_norm = kwargs.pop("pitch_norm", False)

# voiced_mask and p_voiced are saved alongside pitch
def add_voiced_mask(self, **kwargs):
self.voiced_mask_folder = kwargs.pop('voiced_mask_folder', None)

if self.voiced_mask_folder is None:
self.voiced_mask_folder = Path(self.sup_data_path) / Voiced_mask.name

self.voiced_mask_folder.mkdir(exist_ok=True, parents=True)

def add_p_voiced(self, **kwargs):
self.p_voiced_folder = kwargs.pop('p_voiced_folder', None)

if self.p_voiced_folder is None:
self.p_voiced_folder = Path(self.sup_data_path) / P_voiced.name

self.p_voiced_folder.mkdir(exist_ok=True, parents=True)

def add_energy(self, **kwargs):
self.energy_folder = kwargs.pop('energy_folder', None)

@@ -402,6 +423,23 @@ def get_log_mel(self, audio):
return log_mel

def __getitem__(self, index):
def load_from_dir(folder, pyin_index):
    # Supplementary data is cached per utterance; on a cache miss it is computed with
    # librosa.pyin, which returns (f0, voiced_flag, voiced_probs). pyin_index selects
    # which of the three to cache, so pitch, voiced_mask, and p_voiced stay distinct.
    data_path = folder / f"{rel_audio_path_as_text_id}.pt"
    if data_path.exists():
        data = torch.load(data_path).float()
    else:
        pyin_result = librosa.pyin(
            audio.numpy(),
            fmin=self.pitch_fmin,
            fmax=self.pitch_fmax,
            frame_length=self.win_length,
            sr=self.sample_rate,
            fill_na=0.0,
        )
        data = torch.from_numpy(pyin_result[pyin_index]).float()
        torch.save(data, data_path)
    return data

sample = self.data[index]

# Let's keep audio name and all internal directories in rel_audio_path_as_text_id to avoid any collisions
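
For reference, librosa.pyin returns the triple (f0, voiced_flag, voiced_probs) that pyin_index above selects from; a small illustrative call:

import librosa
import numpy as np

y = np.random.randn(22050).astype(np.float32)  # one second of audio, illustrative
f0, voiced_flag, voiced_probs = librosa.pyin(
    y, fmin=65.0, fmax=2093.0, sr=22050, frame_length=1024, fill_na=0.0
)
# f0 -> pitch, voiced_flag -> voiced_mask, voiced_probs -> p_voiced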
@@ -466,29 +504,19 @@ def __getitem__(self, index):
# Load pitch if needed
pitch, pitch_length = None, None
if Pitch in self.sup_data_types_set:
pitch_path = self.pitch_folder / f"{rel_audio_path_as_text_id}.pt"

if pitch_path.exists():
pitch = torch.load(pitch_path).float()
else:
pitch, _, _ = librosa.pyin(
audio.numpy(),
fmin=self.pitch_fmin,
fmax=self.pitch_fmax,
frame_length=self.win_length,
sr=self.sample_rate,
fill_na=0.0,
)
pitch = torch.from_numpy(pitch).float()
torch.save(pitch, pitch_path)

pitch = load_from_dir(self.pitch_folder, pyin_index=0)  # f0
if self.pitch_mean is not None and self.pitch_std is not None and self.pitch_norm:
pitch -= self.pitch_mean
pitch[pitch == -self.pitch_mean] = 0.0 # Zero out values that were previously zero
pitch /= self.pitch_std

pitch_length = torch.tensor(len(pitch)).long()

# Load voiced_mask if needed
voiced_mask = load_from_dir(self.voiced_mask_folder, pyin_index=1) if Voiced_mask in self.sup_data_types_set else None
# Load p_voiced if needed
p_voiced = load_from_dir(self.p_voiced_folder, pyin_index=2) if P_voiced in self.sup_data_types_set else None

# Load energy if needed
energy, energy_length = None, None
if Energy in self.sup_data_types_set:
@@ -522,6 +550,8 @@ def __getitem__(self, index):
energy,
energy_length,
speaker_id,
voiced_mask,
p_voiced,
)

def __len__(self):
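
A quick numeric check of the pitch normalization in __getitem__ above (values are illustrative): unvoiced frames are stored as 0.0, become -pitch_mean after mean subtraction, and are reset to exactly zero before scaling.

import torch

pitch = torch.tensor([0.0, 180.0, 200.0])  # 0.0 marks an unvoiced frame
pitch_mean, pitch_std = 190.0, 40.0
pitch -= pitch_mean                        # unvoiced frame is now -190.0
pitch[pitch == -pitch_mean] = 0.0          # restore unvoiced frames to zero
pitch /= pitch_std                         # tensor([ 0.0000, -0.2500,  0.2500])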
@@ -551,6 +581,8 @@ def general_collate_fn(self, batch):
pitches_lengths,
energies,
energies_lengths,
_,  # speaker ids are re-read per sample below
voiced_masks,
p_voiceds,
) = zip(*batch)

@@ -573,7 +605,17 @@
if AlignPriorMatrix in self.sup_data_types_set
else []
)
audios, tokens, log_mels, durations_list, pitches, energies, speaker_ids = [], [], [], [], [], [], []
audios, tokens, log_mels, durations_list, pitches, energies, speaker_ids, voiced_masks, p_voiceds = (
[],
[],
[],
[],
[],
[],
[],
[],
[],
)

for i, sample_tuple in enumerate(batch):
(
Expand All @@ -590,6 +632,8 @@ def general_collate_fn(self, batch):
energy,
energy_length,
speaker_id,
voiced_mask,
p_voiced,
) = sample_tuple

audio = general_padding(audio, audio_len.item(), max_audio_len)
@@ -608,6 +652,12 @@
] = align_prior_matrix
if Pitch in self.sup_data_types_set:
pitches.append(general_padding(pitch, pitch_length.item(), max_pitches_len))

if Voiced_mask in self.sup_data_types_set:
voiced_masks.append(general_padding(voiced_mask, pitch_length.item(), max_pitches_len))
if P_voiced in self.sup_data_types_set:
p_voiceds.append(general_padding(p_voiced, pitch_length.item(), max_pitches_len))

if Energy in self.sup_data_types_set:
energies.append(general_padding(energy, energy_length.item(), max_energies_len))
if SpeakerID in self.sup_data_types_set:
@@ -627,6 +677,8 @@
"energy": torch.stack(energies) if Energy in self.sup_data_types_set else None,
"energy_lens": torch.stack(energies_lengths) if Energy in self.sup_data_types_set else None,
"speaker_id": torch.stack(speaker_ids) if SpeakerID in self.sup_data_types_set else None,
"voiced_mask": torch.stack(voiced_masks) if Voiced_mask in self.sup_data_types_set else None,
"p_voiced": torch.stack(p_voiceds) if P_voiced in self.sup_data_types_set else None,
}

return data_dict
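
A brief sketch of how the collate function and the new batch keys might be consumed (assuming a standard PyTorch DataLoader and the dataset instance sketched earlier; whether the dataset wires general_collate_fn up as its default collate is not shown in this diff):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=16, collate_fn=dataset.general_collate_fn)
batch = next(iter(loader))
voiced_mask = batch["voiced_mask"]  # None unless Voiced_mask is in sup_data_types
p_voiced = batch["p_voiced"]        # None unless P_voiced is in sup_data_types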
@@ -689,6 +741,8 @@ def __getitem__(self, index):
energy,
energy_length,
speaker_id,
voiced_mask,
p_voiced,
) = super().__getitem__(index)

lm_tokens = None
Expand All @@ -709,6 +763,8 @@ def __getitem__(self, index):
energy,
energy_length,
speaker_id,
voiced_mask,
p_voiced,
lm_tokens,
)
