diff --git a/Jenkinsfile b/Jenkinsfile index 4567a7f15399..8c50f0de997e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -4509,4 +4509,4 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' cleanWs() } } -} +} \ No newline at end of file diff --git a/examples/tts/conf/vits.yaml b/examples/tts/conf/vits.yaml new file mode 100644 index 000000000000..1002bdfe89f5 --- /dev/null +++ b/examples/tts/conf/vits.yaml @@ -0,0 +1,215 @@ +# This config contains the default values for training VITS model on LJSpeech dataset. +# If you want to train model on other dataset, you can change config values according to your dataset. +# Most dataset-specific arguments are in the head of the config file, see below. + +# TODO: remove unnecessary arguments, refactoring + +name: VITS + +train_dataset: ??? +validation_datasets: ??? +sup_data_path: null +sup_data_types: null + +phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv22.10.txt" +heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" +whitelist_path: "nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv" + +# Default values from librosa.pyin +pitch_fmin: 65.40639132514966 +pitch_fmax: 2093.004522404789 + +sample_rate: 22050 +n_mel_channels: 80 +n_window_size: 1024 +n_window_stride: 256 +n_fft: 1024 +lowfreq: 0 +highfreq: null +window: hann + +model: + pitch_fmin: ${pitch_fmin} + pitch_fmax: ${pitch_fmax} + + sample_rate: ${sample_rate} + n_mel_channels: ${n_mel_channels} + n_window_size: ${n_window_size} + n_window_stride: ${n_window_stride} + n_fft: ${n_fft} + lowfreq: ${lowfreq} + highfreq: ${highfreq} + window: ${window} + mel_fmin: 0.0 + mel_fmax: null + + n_speakers: 0 + segment_size: 8192 + c_mel: 45 + c_kl: 1. + use_spectral_norm: false + + text_normalizer: + _target_: nemo_text_processing.text_normalization.normalize.Normalizer + lang: en + input_case: cased + whitelist: ${whitelist_path} + + text_normalizer_call_kwargs: + verbose: false + punct_pre_process: true + punct_post_process: true + + text_tokenizer: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo_text_processing.g2p.modules.IPAG2P + phoneme_dict: ${phoneme_dict_path} + heteronyms: ${heteronyms_path} + phoneme_probability: 0.8 + # Relies on the heteronyms list for anything that needs to be disambiguated + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + + train_ds: + dataset: + _target_: "nemo.collections.tts.torch.data.TTSDataset" + manifest_filepath: ${train_dataset} + sample_rate: ${model.sample_rate} + sup_data_path: ${sup_data_path} + sup_data_types: ${sup_data_types} + n_fft: ${model.n_fft} + win_length: ${model.n_window_size} + hop_length: ${model.n_window_stride} + window: ${model.window} + n_mels: ${model.n_mel_channels} + lowfreq: ${model.lowfreq} + highfreq: ${model.highfreq} + max_duration: null + min_duration: 0.1 + ignore_file: null + trim: False + pitch_fmin: ${model.pitch_fmin} + pitch_fmax: ${model.pitch_fmax} + + dataloader_params: + num_workers: 8 + pin_memory: false + + batch_sampler: + batch_size: 32 + boundaries: [32,300,400,500,600,700,800,900,1000] + num_replicas: ${trainer.devices} + shuffle: true + + validation_ds: + dataset: + _target_: "nemo.collections.tts.torch.data.TTSDataset" + manifest_filepath: ${validation_datasets} + sample_rate: ${model.sample_rate} + sup_data_path: ${sup_data_path} + sup_data_types: ${sup_data_types} + n_fft: ${model.n_fft} + win_length: ${model.n_window_size} + hop_length: ${model.n_window_stride} + window: ${model.window} + n_mels: ${model.n_mel_channels} + lowfreq: ${model.lowfreq} + highfreq: ${model.highfreq} + max_duration: null + min_duration: 0.1 + ignore_file: null + trim: False + pitch_fmin: ${model.pitch_fmin} + pitch_fmax: ${model.pitch_fmax} + + dataloader_params: + drop_last: false + shuffle: false + batch_size: 16 + num_workers: 4 + pin_memory: false + + preprocessor: + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures + nfilt: ${model.n_mel_channels} + highfreq: ${model.highfreq} + log: true + log_zero_guard_type: clamp + log_zero_guard_value: 1e-05 + lowfreq: ${model.lowfreq} + n_fft: ${model.n_fft} + n_window_size: ${model.n_window_size} + n_window_stride: ${model.n_window_stride} + pad_to: 1 + pad_value: 0 + sample_rate: ${model.sample_rate} + window: ${model.window} + normalize: null + preemph: null + dither: 0.0 + frame_splicing: 1 + stft_conv: false + nb_augmentation_prob : 0 + mag_power: 1.0 + exact_pad: true + use_grads: true + + synthesizer: + _target_: nemo.collections.tts.modules.vits_modules.SynthesizerTrn + inter_channels: 192 + hidden_channels: 192 + filter_channels: 768 + n_heads: 2 + n_layers: 6 + kernel_size: 3 + p_dropout: 0.1 + resblock: "1" + resblock_kernel_sizes: [3,7,11] + resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] + upsample_rates: [8,8,2,2] + upsample_initial_channel: 512 + upsample_kernel_sizes: [16,16,4,4] + n_speakers: ${model.n_speakers} + gin_channels: 256 # for multi-speaker + + optim: + _target_: torch.optim.AdamW + lr: 2e-4 + betas: [0.9, 0.99] + eps: 1e-9 + + sched: + name: ExponentialLR + lr_decay: 0.999875 + +trainer: + num_nodes: 1 + devices: 2 + accelerator: gpu + strategy: ddp + precision: 32 + # amp_backend: 'apex' + # amp_level: 'O2' + # benchmark: true + max_epochs: -1 + accumulate_grad_batches: 1 + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 50 + check_val_every_n_epoch: 1 + +exp_manager: + exp_dir: ??? + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: loss_gen_all + mode: min + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/tts/conf/vits_44100.yaml b/examples/tts/conf/vits_44100.yaml new file mode 100644 index 000000000000..c9955e70abce --- /dev/null +++ b/examples/tts/conf/vits_44100.yaml @@ -0,0 +1,211 @@ +# This config contains the default values for training VITS model on LJSpeech dataset. +# If you want to train model on other dataset, you can change config values according to your dataset. +# Most dataset-specific arguments are in the head of the config file, see below. + +name: VITS + +train_dataset: ??? +validation_datasets: ??? +sup_data_path: ??? +sup_data_types: [speaker_id] + +pitch_fmin: 65.40639132514966 +pitch_fmax: 2093.004522404789 + +sample_rate: 44100 +n_mel_channels: 80 +n_window_size: 2048 +n_window_stride: 512 +n_fft: 2048 +lowfreq: 0 +highfreq: null +window: hann + +phoneme_dict_path: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv22.10.txt" +heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" +whitelist_path: "nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv" + +model: + n_speakers: 13000 + segment_size: 16384 + c_mel: 45 + c_kl: 1. + use_spectral_norm: false + + pitch_fmin: ${pitch_fmin} + pitch_fmax: ${pitch_fmax} + + sample_rate: ${sample_rate} + n_mel_channels: ${n_mel_channels} + n_window_size: ${n_window_size} + n_window_stride: ${n_window_stride} + n_fft: ${n_fft} + lowfreq: ${lowfreq} + highfreq: ${highfreq} + window: ${window} + + text_normalizer: + _target_: nemo_text_processing.text_normalization.normalize.Normalizer + lang: en + input_case: cased + whitelist: ${whitelist_path} + + text_normalizer_call_kwargs: + verbose: false + punct_pre_process: true + punct_post_process: true + + text_tokenizer: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: false + g2p: + _target_: nemo_text_processing.g2p.modules.IPAG2P + phoneme_dict: ${phoneme_dict_path} + heteronyms: ${heteronyms_path} + phoneme_probability: 0.8 + # Relies on the heteronyms list for anything that needs to be disambiguated + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + + train_ds: + dataset: + _target_: "nemo.collections.tts.torch.data.TTSDataset" + manifest_filepath: ${train_dataset} + sample_rate: ${model.sample_rate} + sup_data_path: ${sup_data_path} + sup_data_types: ${sup_data_types} + n_fft: ${model.n_fft} + win_length: ${model.n_window_size} + hop_length: ${model.n_window_stride} + window: ${model.window} + n_mels: ${model.n_mel_channels} + lowfreq: ${model.lowfreq} + highfreq: ${model.highfreq} + max_duration: null + min_duration: 0.1 + ignore_file: null + trim: False + pitch_fmin: ${model.pitch_fmin} + pitch_fmax: ${model.pitch_fmax} + + dataloader_params: + num_workers: 8 + pin_memory: false + + batch_sampler: + batch_size: 32 + boundaries: [32,300,400,500,600,700,800,900,1000] + num_replicas: ${trainer.devices} + shuffle: true + + validation_ds: + dataset: + _target_: "nemo.collections.tts.torch.data.TTSDataset" + manifest_filepath: ${validation_datasets} + sample_rate: ${model.sample_rate} + sup_data_path: ${sup_data_path} + sup_data_types: ${sup_data_types} + n_fft: ${model.n_fft} + win_length: ${model.n_window_size} + hop_length: ${model.n_window_stride} + window: ${model.window} + n_mels: ${model.n_mel_channels} + lowfreq: ${model.lowfreq} + highfreq: ${model.highfreq} + max_duration: null + min_duration: 0.1 + ignore_file: null + trim: False + pitch_fmin: ${model.pitch_fmin} + pitch_fmax: ${model.pitch_fmax} + + dataloader_params: + drop_last: false + shuffle: false + batch_size: 32 + num_workers: 4 + pin_memory: false + + preprocessor: + _target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures + nfilt: ${model.n_mel_channels} + highfreq: ${model.highfreq} + log: true + log_zero_guard_type: clamp + log_zero_guard_value: 1e-05 + lowfreq: ${model.lowfreq} + n_fft: ${model.n_fft} + n_window_size: ${model.n_window_size} + n_window_stride: ${model.n_window_stride} + pad_to: 1 + pad_value: 0 + sample_rate: ${model.sample_rate} + window: ${model.window} + normalize: null + preemph: null + dither: 0.0 + frame_splicing: 1 + stft_conv: false + nb_augmentation_prob : 0 + mag_power: 1.0 + exact_pad: true + use_grads: true + + synthesizer: + _target_: nemo.collections.tts.modules.vits_modules.SynthesizerTrn + inter_channels: 192 + hidden_channels: 192 + filter_channels: 768 + n_heads: 2 + n_layers: 6 + kernel_size: 3 + p_dropout: 0.1 + resblock: "1" + resblock_kernel_sizes: [3,7,11] + resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] + upsample_rates: [8,8,4,2] + upsample_initial_channel: 512 + upsample_kernel_sizes: [16,16,4,4] + n_speakers: ${model.n_speakers} + gin_channels: 256 # for multi-speaker + + optim: + _target_: torch.optim.AdamW + lr: 2e-4 + betas: [0.9, 0.99] + eps: 1e-9 + + sched: + name: CosineAnnealing + max_steps: 1000000 + min_lr: 1e-5 + +trainer: + num_nodes: 1 + devices: 2 + accelerator: gpu + strategy: ddp + precision: 32 + # amp_backend: 'apex' + # amp_level: 'O2' + # benchmark: true + max_epochs: -1 + accumulate_grad_batches: 1 + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 50 + check_val_every_n_epoch: 1 + +exp_manager: + exp_dir: ??? + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: loss_gen_all + mode: min + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/tts/vits.py b/examples/tts/vits.py new file mode 100644 index 000000000000..ac966900ba47 --- /dev/null +++ b/examples/tts/vits.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl + +from nemo.collections.common.callbacks import LogEpochTimeCallback +from nemo.collections.tts.models.vits import VitsModel +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="vits") +def main(cfg): + trainer = pl.Trainer(replace_sampler_ddp=False, **cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + model = VitsModel(cfg=cfg.model, trainer=trainer) + + trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()]) + trainer.fit(model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index 62e6969a0dff..b8cf3b07859d 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -373,7 +373,7 @@ def get_seq_len(self, seq_len): def filter_banks(self): return self.fb - def forward(self, x, seq_len): + def forward(self, x, seq_len, linear_spec=False): seq_len = self.get_seq_len(seq_len.float()) if self.stft_pad_amount is not None: @@ -408,9 +408,12 @@ def forward(self, x, seq_len): if self.mag_power != 1.0: x = x.pow(self.mag_power) + # return plain spectrogram if required + if linear_spec: + return x, seq_len + # dot with filterbank energies x = torch.matmul(self.fb.to(x.dtype), x) - # log features if required if self.log: if self.log_zero_guard_type == "add": diff --git a/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py b/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py index 385f30ee55c9..746c783bd1b6 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py +++ b/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py @@ -23,6 +23,13 @@ ')', '[', ']', '{', '}', ) +VITS_PUNCTUATION = ( + ',', '.', '!', '?', '-', + ':', ';', '"', '«', '»', + '“', '”', '¡', '¿', '—', + '…', +) + GRAPHEME_CHARACTER_SETS = { "en-US": ( 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', diff --git a/nemo/collections/tts/helpers/helpers.py b/nemo/collections/tts/helpers/helpers.py index 554f7c0d2cab..561a9e0564c6 100644 --- a/nemo/collections/tts/helpers/helpers.py +++ b/nemo/collections/tts/helpers/helpers.py @@ -586,6 +586,80 @@ def split_view(tensor, split_size: int, dim: int = 0): return tensor.reshape(*new_shape) +def slice_segments(x, ids_str, segment_size=4): + """ + Time-wise slicing (patching) of bathches for audio/spectrogram + [B x C x T] -> [B x C x segment_size] + """ + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + x_i = x[i] + if idx_end >= x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (idx_end + 1) - x.size(2))) + ret[i] = x_i[:, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + """ + Chooses random indices and slices segments from batch + [B x C x T] -> [B x C x segment_size] + """ + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str_max = ids_str_max.to(device=x.device) + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + + ret = slice_segments(x, ids_str, segment_size) + + return ret, ids_str + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = get_mask_from_lengths(cum_duration_flat, torch.Tensor(t_y).reshape(1, 1, -1)).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + def process_batch(batch_data, sup_data_types_set): batch_dict = {} batch_index = 0 diff --git a/nemo/collections/tts/helpers/splines.py b/nemo/collections/tts/helpers/splines.py index e697f0671200..a4494efa2b0e 100644 --- a/nemo/collections/tts/helpers/splines.py +++ b/nemo/collections/tts/helpers/splines.py @@ -18,7 +18,7 @@ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - +import numpy as np import torch import torch.nn.functional as F @@ -288,3 +288,198 @@ def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False): # make sure it falls into [0,1) inv = inv.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps) return inv, None + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=1e-3, + min_bin_height=1e-3, + min_derivative=1e-3, +): + + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {'tails': tails, 'tail_bound': tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails='linear', + tail_bound=1.0, + min_bin_width=1e-3, + min_bin_height=1e-3, + min_derivative=1e-3, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == 'linear': + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError('{} tails are not implemented.'.format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=1e-3, + min_bin_height=1e-3, + min_derivative=1e-3, +): + + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError('Input to a transform is not within its domain') + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError('Minimal bin width too large for the number of bins') + if min_bin_height * num_bins > 1.0: + raise ValueError('Minimal bin height too large for the number of bins') + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/nemo/collections/tts/losses/hifigan_losses.py b/nemo/collections/tts/losses/hifigan_losses.py index 1386606b3f84..649f075994d8 100644 --- a/nemo/collections/tts/losses/hifigan_losses.py +++ b/nemo/collections/tts/losses/hifigan_losses.py @@ -35,7 +35,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -# The forward functions onf the following classes are based on code from https://github.com/jik876/hifi-gan: +# The forward functions of the following classes are based on code from https://github.com/jik876/hifi-gan: # FeatureMatchingLoss, DiscriminatorLoss, GeneratorLoss import torch diff --git a/nemo/collections/tts/losses/vits_losses.py b/nemo/collections/tts/losses/vits_losses.py new file mode 100644 index 000000000000..b2945a2aa362 --- /dev/null +++ b/nemo/collections/tts/losses/vits_losses.py @@ -0,0 +1,177 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2021 Jaehyeon Kim +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# The forward functions of the following classes are based on code from https://github.com/jaywalnut310/vits: +# KlLoss + +import torch + +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types.elements import LossType, VoidType +from nemo.core.neural_types.neural_type import NeuralType + + +class KlLoss(Loss): + @property + def input_types(self): + return { + "z_p": [NeuralType(('B', 'D', 'T'), VoidType())], + "logs_q": [NeuralType(('B', 'D', 'T'), VoidType())], + "m_p": [NeuralType(('B', 'D', 'T'), VoidType())], + "logs_p": [NeuralType(('B', 'D', 'T'), VoidType())], + "z_mask": [NeuralType(('B', 'D', 'T'), VoidType())], + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p: Input distribution + logs_q: LogVariance of target distrubution + m_p: Mean of input distrubution + logs_p: LogVariance of input distrubution + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l + + +class FeatureMatchingLoss(Loss): + """VITS Feature Matching Loss module""" + + @property + def input_types(self): + return { + "fmap_r": [[NeuralType(elements_type=VoidType())]], + "fmap_g": [[NeuralType(elements_type=VoidType())]], + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + } + + @typecheck() + def forward(self, fmap_r, fmap_g): + """ + fmap_r, fmap_g: List[List[Tensor]] + """ + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +class DiscriminatorLoss(Loss): + """Discriminator Loss module""" + + @property + def input_types(self): + return { + "disc_real_outputs": [NeuralType(('B', 'T'), VoidType())], + "disc_generated_outputs": [NeuralType(('B', 'T'), VoidType())], + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + "real_losses": [NeuralType(elements_type=LossType())], + "fake_losses": [NeuralType(elements_type=LossType())], + } + + @typecheck() + def forward(self, disc_real_outputs, disc_generated_outputs): + r_losses = [] + g_losses = [] + loss = 0 + for i, (dr, dg) in enumerate(zip(disc_real_outputs, disc_generated_outputs)): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +class GeneratorLoss(Loss): + """Generator Loss module""" + + @property + def input_types(self): + return { + "disc_outputs": [NeuralType(('B', 'T'), VoidType())], + } + + @property + def output_types(self): + return { + "loss": NeuralType(elements_type=LossType()), + "fake_losses": [NeuralType(elements_type=LossType())], + } + + @typecheck() + def forward(self, disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/nemo/collections/tts/models/__init__.py b/nemo/collections/tts/models/__init__.py index bb0d11aad114..adb93b65e61a 100644 --- a/nemo/collections/tts/models/__init__.py +++ b/nemo/collections/tts/models/__init__.py @@ -20,6 +20,7 @@ from nemo.collections.tts.models.tacotron2 import Tacotron2Model from nemo.collections.tts.models.two_stages import GriffinLimModel, MelPsuedoInverseModel, TwoStagesModel from nemo.collections.tts.models.univnet import UnivNetModel +from nemo.collections.tts.models.vits import VitsModel from nemo.collections.tts.models.waveglow import WaveGlowModel __all__ = [ @@ -33,5 +34,6 @@ "Tacotron2Model", "TwoStagesModel", "UnivNetModel", + "VitsModel", "WaveGlowModel", ] diff --git a/nemo/collections/tts/models/base.py b/nemo/collections/tts/models/base.py index 6fe2fb3f8806..754e03550a9b 100644 --- a/nemo/collections/tts/models/base.py +++ b/nemo/collections/tts/models/base.py @@ -233,3 +233,39 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]': if subclass_models is not None and len(subclass_models) > 0: list_of_models.extend(subclass_models) return list_of_models + + +class TextToWaveform(ModelPT, ABC): + """ Base class for all end-to-end TTS models that generate a waveform from text """ + + @abstractmethod + def parse(self, str_input: str, **kwargs) -> 'torch.tensor': + """ + A helper function that accepts a raw python string and turns it into a tensor. The tensor should have 2 + dimensions. The first is the batch, which should be of size 1. The second should represent time. The tensor + should represent either tokenized or embedded text, depending on the model. + """ + + @abstractmethod + def convert_text_to_waveform(self, *, tokens: 'torch.tensor', **kwargs) -> 'List[torch.tensor]': + """ + Accepts a batch of text and returns a list containing a batch of audio + Args: + tokens: A torch tensor representing the text to be converted to speech + Returns: + audio: A list of length batch_size containing torch tensors representing the waveform output + """ + + @classmethod + def list_available_models(cls) -> 'List[PretrainedModelInfo]': + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + Returns: + List of available pre-trained models. + """ + list_of_models = [] + for subclass in cls.__subclasses__(): + subclass_models = subclass.list_available_models() + if subclass_models is not None and len(subclass_models) > 0: + list_of_models.extend(subclass_models) + return list_of_models diff --git a/nemo/collections/tts/models/vits.py b/nemo/collections/tts/models/vits.py new file mode 100644 index 000000000000..d035c6a1b3ac --- /dev/null +++ b/nemo/collections/tts/models/vits.py @@ -0,0 +1,384 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import contextlib + +import omegaconf +import torch +import wandb +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import WandbLogger +from torch.cuda.amp import autocast +from torch.nn import functional as F + +from nemo.collections.tts.helpers.helpers import clip_grad_value_, plot_spectrogram_to_numpy, slice_segments +from nemo.collections.tts.losses.vits_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss, KlLoss +from nemo.collections.tts.models.base import TextToWaveform +from nemo.collections.tts.modules.vits_modules import MultiPeriodDiscriminator +from nemo.collections.tts.torch.data import DistributedBucketSampler +from nemo.collections.tts.torch.tts_data_types import SpeakerID +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types.elements import AudioSignal, FloatType, Index, IntType, TokenIndex +from nemo.core.neural_types.neural_type import NeuralType +from nemo.core.optim.lr_scheduler import CosineAnnealing +from nemo.utils import logging, model_utils +from nemo.utils.decorators.experimental import experimental + +HAVE_WANDB = True +try: + import wandb +except ModuleNotFoundError: + HAVE_WANDB = False + + +@experimental +class VitsModel(TextToWaveform): + def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): + # Convert to Hydra 1.0 compatible DictConfig + + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + # setup normalizer + self.normalizer = None + self.text_normalizer_call = None + self.text_normalizer_call_kwargs = {} + self._setup_normalizer(cfg) + + # setup tokenizer + self.tokenizer = None + self._setup_tokenizer(cfg) + assert self.tokenizer is not None + + num_tokens = len(self.tokenizer.tokens) + self.tokenizer_pad = self.tokenizer.pad + + super().__init__(cfg=cfg, trainer=trainer) + + self.audio_to_melspec_processor = instantiate(cfg.preprocessor, highfreq=cfg.train_ds.dataset.highfreq) + + self.feat_matching_loss = FeatureMatchingLoss() + self.disc_loss = DiscriminatorLoss() + self.gen_loss = GeneratorLoss() + self.kl_loss = KlLoss() + + self.net_g = instantiate( + cfg.synthesizer, + n_vocab=num_tokens, + spec_channels=cfg.n_fft // 2 + 1, + segment_size=cfg.segment_size // cfg.n_window_stride, + padding_idx=self.tokenizer_pad, + ) + + self.net_d = MultiPeriodDiscriminator(cfg.use_spectral_norm) + + self.automatic_optimization = False + + def _setup_normalizer(self, cfg): + if "text_normalizer" in cfg: + normalizer_kwargs = {} + + if "whitelist" in cfg.text_normalizer: + normalizer_kwargs["whitelist"] = self.register_artifact( + 'text_normalizer.whitelist', cfg.text_normalizer.whitelist + ) + + self.normalizer = instantiate(cfg.text_normalizer, **normalizer_kwargs) + self.text_normalizer_call = self.normalizer.normalize + if "text_normalizer_call_kwargs" in cfg: + self.text_normalizer_call_kwargs = cfg.text_normalizer_call_kwargs + + def _setup_tokenizer(self, cfg): + text_tokenizer_kwargs = {} + if "g2p" in cfg.text_tokenizer and cfg.text_tokenizer.g2p is not None: + g2p_kwargs = {} + + if "phoneme_dict" in cfg.text_tokenizer.g2p: + g2p_kwargs["phoneme_dict"] = self.register_artifact( + 'text_tokenizer.g2p.phoneme_dict', cfg.text_tokenizer.g2p.phoneme_dict, + ) + + if "heteronyms" in cfg.text_tokenizer.g2p: + g2p_kwargs["heteronyms"] = self.register_artifact( + 'text_tokenizer.g2p.heteronyms', cfg.text_tokenizer.g2p.heteronyms, + ) + + text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p, **g2p_kwargs) + + self.tokenizer = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs) + + def parse(self, text: str, normalize=True) -> torch.tensor: + if self.training: + logging.warning("parse() is meant to be called in eval mode.") + if normalize and self.text_normalizer_call is not None: + text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs) + + eval_phon_mode = contextlib.nullcontext() + if hasattr(self.tokenizer, "set_phone_prob"): + eval_phon_mode = self.tokenizer.set_phone_prob(prob=1.0) + + with eval_phon_mode: + tokens = self.tokenizer.encode(text) + + return torch.tensor(tokens).long().unsqueeze(0).to(self.device) + + def configure_optimizers(self): + optim_config = self._cfg.optim.copy() + OmegaConf.set_struct(optim_config, False) + sched_config = optim_config.pop("sched", None) + OmegaConf.set_struct(optim_config, True) + + optim_g = instantiate(optim_config, params=self.net_g.parameters(),) + optim_d = instantiate(optim_config, params=self.net_d.parameters(),) + + if sched_config is not None: + if sched_config.name == 'ExponentialLR': + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=sched_config.lr_decay) + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=sched_config.lr_decay) + elif sched_config.name == 'CosineAnnealing': + scheduler_g = CosineAnnealing( + optimizer=optim_g, max_steps=sched_config.max_steps, min_lr=sched_config.min_lr, + ) + scheduler_d = CosineAnnealing( + optimizer=optim_d, max_steps=sched_config.max_steps, min_lr=sched_config.min_lr, + ) + else: + raise ValueError("Unknown optimizer.") + + scheduler_g_dict = {'scheduler': scheduler_g, 'interval': 'step'} + scheduler_d_dict = {'scheduler': scheduler_d, 'interval': 'step'} + return [optim_g, optim_d], [scheduler_g_dict, scheduler_d_dict] + else: + return [optim_g, optim_d] + + # for inference + @typecheck( + input_types={ + "tokens": NeuralType(('B', 'T_text'), TokenIndex()), + "speakers": NeuralType(('B',), Index(), optional=True), + "noise_scale": NeuralType(('B',), FloatType(), optional=True), + "length_scale": NeuralType(('B',), FloatType(), optional=True), + "noise_scale_w": NeuralType(('B',), FloatType(), optional=True), + "max_len": NeuralType(('B',), IntType(), optional=True), + } + ) + def forward(self, tokens, speakers=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, max_len=1000): + text_len = torch.tensor([tokens.size(-1)]).to(int).to(tokens.device) + audio_pred, attn, y_mask, (z, z_p, m_p, logs_p) = self.net_g.infer( + tokens, + text_len, + speakers=speakers, + noise_scale=noise_scale, + length_scale=length_scale, + noise_scale_w=noise_scale_w, + max_len=max_len, + ) + return audio_pred, attn, y_mask, (z, z_p, m_p, logs_p) + + def training_step(self, batch, batch_idx): + speakers = None + if SpeakerID in self._train_dl.dataset.sup_data_types_set: + (audio, audio_len, text, text_len, speakers) = batch + else: + (audio, audio_len, text, text_len) = batch + + spec, spec_lengths = self.audio_to_melspec_processor(audio, audio_len, linear_spec=True) + + with autocast(enabled=True): + audio_pred, l_length, attn, ids_slice, text_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = self.net_g( + text, text_len, spec, spec_lengths, speakers + ) + + audio_pred = audio_pred.float() + + audio_pred_mel, _ = self.audio_to_melspec_processor(audio_pred.squeeze(1), audio_len, linear_spec=False) + + audio = slice_segments(audio.unsqueeze(1), ids_slice * self.cfg.n_window_stride, self._cfg.segment_size) + audio_mel, _ = self.audio_to_melspec_processor(audio.squeeze(1), audio_len, linear_spec=False) + + with autocast(enabled=True): + y_d_hat_r, y_d_hat_g, _, _ = self.net_d(audio, audio_pred.detach()) + + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = self.disc_loss( + disc_real_outputs=y_d_hat_r, disc_generated_outputs=y_d_hat_g + ) + loss_disc_all = loss_disc + + # get optimizers + optim_g, optim_d = self.optimizers() + + # train discriminator + optim_d.zero_grad() + self.manual_backward(loss_disc_all) + norm_d = clip_grad_value_(self.net_d.parameters(), None) + optim_d.step() + + with autocast(enabled=True): + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(audio, audio_pred) + # Generator + with autocast(enabled=False): + loss_dur = torch.sum(l_length.float()) + loss_mel = F.l1_loss(audio_mel, audio_pred_mel) * self._cfg.c_mel + loss_kl = self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask) * self._cfg.c_kl + loss_fm = self.feat_matching_loss(fmap_r=fmap_r, fmap_g=fmap_g) + loss_gen, losses_gen = self.gen_loss(disc_outputs=y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + + # train generator + optim_g.zero_grad() + self.manual_backward(loss_gen_all) + norm_g = clip_grad_value_(self.net_g.parameters(), None) + optim_g.step() + + schedulers = self.lr_schedulers() + if schedulers is not None: + sch1, sch2 = schedulers + if ( + self.trainer.is_last_batch + and isinstance(sch1, torch.optim.lr_scheduler.ExponentialLR) + or isinstance(sch1, CosineAnnealing) + ): + sch1.step() + sch2.step() + + metrics = { + "loss_gen": loss_gen, + "loss_fm": loss_fm, + "loss_mel": loss_mel, + "loss_dur": loss_dur, + "loss_kl": loss_kl, + "loss_gen_all": loss_gen_all, + "loss_disc_all": loss_disc_all, + "grad_gen": norm_g, + "grad_disc": norm_d, + } + + for i, v in enumerate(losses_gen): + metrics[f"loss_gen_i_{i}"] = v + + for i, v in enumerate(losses_disc_r): + metrics[f"loss_disc_r_{i}"] = v + + for i, v in enumerate(losses_disc_g): + metrics[f"loss_disc_g_{i}"] = v + + self.log_dict(metrics, on_step=True, sync_dist=True) + + def validation_step(self, batch, batch_idx): + speakers = None + if self.cfg.n_speakers > 1: + (audio, audio_len, text, text_len, speakers) = batch + else: + (audio, audio_len, text, text_len) = batch + + audio_pred, _, mask, *_ = self.net_g.infer(text, text_len, speakers, max_len=1000) + + audio_pred = audio_pred.squeeze() + audio_pred_len = mask.sum([1, 2]).long() * self._cfg.validation_ds.dataset.hop_length + + mel, mel_lengths = self.audio_to_melspec_processor(audio, audio_len) + audio_pred_mel, audio_pred_mel_len = self.audio_to_melspec_processor(audio_pred, audio_pred_len) + + # plot audio once per epoch + if batch_idx == 0 and isinstance(self.logger, WandbLogger) and HAVE_WANDB: + logger = self.logger.experiment + + specs = [] + audios = [] + specs += [ + wandb.Image( + plot_spectrogram_to_numpy(mel[0, :, : mel_lengths[0]].data.cpu().numpy()), + caption=f"val_mel_target", + ), + wandb.Image( + plot_spectrogram_to_numpy(audio_pred_mel[0, :, : audio_pred_mel_len[0]].data.cpu().numpy()), + caption=f"val_mel_predicted", + ), + ] + + audios += [ + wandb.Audio( + audio[0, : audio_len[0]].data.cpu().to(torch.float).numpy(), + caption=f"val_wav_target", + sample_rate=self._cfg.sample_rate, + ), + wandb.Audio( + audio_pred[0, : audio_pred_len[0]].data.cpu().to(torch.float).numpy(), + caption=f"val_wav_predicted", + sample_rate=self._cfg.sample_rate, + ), + ] + + logger.log({"specs": specs, "audios": audios}) + + def _loader(self, cfg): + try: + _ = cfg['dataset']['manifest_filepath'] + except omegaconf.errors.MissingMandatoryValue: + logging.warning("manifest_filepath was skipped. No dataset for this model.") + return None + + dataset = instantiate( + cfg.dataset, + text_normalizer=self.normalizer, + text_normalizer_call_kwargs=self.text_normalizer_call_kwargs, + text_tokenizer=self.tokenizer, + ) + return torch.utils.data.DataLoader( # noqa + dataset=dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params, + ) + + def train_dataloader(self): + # default used by the Trainer + dataset = instantiate( + self.cfg.train_ds.dataset, + text_normalizer=self.normalizer, + text_normalizer_call_kwargs=self.text_normalizer_call_kwargs, + text_tokenizer=self.tokenizer, + ) + + train_sampler = DistributedBucketSampler(dataset, **self.cfg.train_ds.batch_sampler) + + dataloader = torch.utils.data.DataLoader( + dataset, collate_fn=dataset.collate_fn, batch_sampler=train_sampler, **self.cfg.train_ds.dataloader_params, + ) + return dataloader + + def setup_training_data(self, cfg): + self._train_dl = self._loader(cfg) + + def setup_validation_data(self, cfg): + self._validation_dl = self._loader(cfg) + + def setup_test_data(self, cfg): + """Omitted.""" + pass + + @classmethod + def list_available_models(cls) -> 'List[PretrainedModelInfo]': + list_of_models = [] + # TODO: List available models?? + return list_of_models + + @typecheck( + input_types={"tokens": NeuralType(('B', 'T_text'), TokenIndex(), optional=True),}, + output_types={"audio": NeuralType(('B', 'T_audio'), AudioSignal())}, + ) + def convert_text_to_waveform(self, *, tokens, speakers=None): + audio = self(tokens=tokens, speakers=speakers)[0].squeeze(1) + return audio diff --git a/nemo/collections/tts/modules/monotonic_align/__init__.py b/nemo/collections/tts/modules/monotonic_align/__init__.py new file mode 100644 index 000000000000..da36a9eccd7e --- /dev/null +++ b/nemo/collections/tts/modules/monotonic_align/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2021 Jaehyeon Kim +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from .numba_core import maximum_path diff --git a/nemo/collections/tts/modules/monotonic_align/numba_core.py b/nemo/collections/tts/modules/monotonic_align/numba_core.py new file mode 100644 index 000000000000..e55ea964d6b8 --- /dev/null +++ b/nemo/collections/tts/modules/monotonic_align/numba_core.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numba +import numpy as np +import torch + + +def maximum_path(neg_cent, mask): + """ Numba version. + neg_cent: [b, t_t, t_s] + mask: [b, t_t, t_s] + """ + device = neg_cent.device + dtype = neg_cent.dtype + neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) + path = np.zeros(neg_cent.shape, dtype=np.int32) + + t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) + t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) + maximum_path_c(path, neg_cent, t_t_max, t_s_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +@numba.jit(nopython=True, boundscheck=False, parallel=True) +def maximum_path_each(path, value, t_y: int, t_x: int, max_neg_val=-1e9): + """ + Args: + path: int32[:, :] + value: float32[:, :] + t_y: int + t_x: int + max_neg_val: float + """ + index: int = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@numba.jit(nopython=True, boundscheck=False, parallel=True) +def maximum_path_c(paths, values, t_ys, t_xs): + """ + Args: + paths: int32[:, :, :] + values: float32[:, :, :] + t_ys: int[:] + t_xs: int[:] + """ + b: int = paths.shape[0] + for i in numba.prange(b): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) + + +if __name__ == '__main__': + pass diff --git a/nemo/collections/tts/modules/vits_modules.py b/nemo/collections/tts/modules/vits_modules.py new file mode 100644 index 000000000000..1793f1f10565 --- /dev/null +++ b/nemo/collections/tts/modules/vits_modules.py @@ -0,0 +1,1214 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2021 Jaehyeon Kim +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from nemo.collections.tts.helpers.helpers import ( + convert_pad_shape, + generate_path, + get_mask_from_lengths, + rand_slice_segments, +) +from nemo.collections.tts.helpers.splines import piecewise_rational_quadratic_transform +from nemo.collections.tts.modules.hifigan_modules import ResBlock1, ResBlock2, get_padding, init_weights +from nemo.collections.tts.modules.monotonic_align import maximum_path + +LRELU_SLOPE = 0.1 + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dilated and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails='linear', + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = Log() + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + # torch.manual_seed(1) + # torch.cuda.manual_seed(1) + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + padding_idx, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels, padding_idx=padding_idx) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + + self.encoder = AttentionEncoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(get_mask_from_lengths(x_lengths, x), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(get_mask_from_lengths(x_lengths, x), 1).to(x.dtype).to(device=x.device) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = ResBlock1 if resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + nn.ConvTranspose1d( + upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = torch.zeros(x.shape, dtype=x.dtype, device=x.device) + for j in range(self.num_kernels): + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ] + ) + self.dropout = nn.Dropout(0.3) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = self.dropout(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), + norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.dropout = nn.Dropout(0.3) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + padding_idx, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + **kwargs + ): + + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.padding_idx = padding_idx + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.use_sdp = use_sdp + + self.enc_p = TextEncoder( + n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + padding_idx, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels + ) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + if use_sdp: + self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + else: + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + + if n_speakers > 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def forward(self, text, text_len, spec, spec_len, speakers=None): + x, mean_prior, logscale_prior, text_mask = self.enc_p(text, text_len) + if self.n_speakers > 1: + g = self.emb_g(speakers).unsqueeze(-1) # [b, h, 1] + else: + g = None + + z, mean_posterior, logscale_posterior, spec_mask = self.enc_q(spec, spec_len, g=g) + z_p = self.flow(z, spec_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logscale_prior) # [b, d, t] + neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logscale_prior, [1], keepdim=True) # [b, 1, t_s] + neg_cent2 = torch.matmul( + -0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r + ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul( + z_p.transpose(1, 2), (mean_prior * s_p_sq_r) + ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum(-0.5 * (mean_prior ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(text_mask, 2) * torch.unsqueeze(spec_mask, -1) + attn = maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, text_mask, w, g=g) + l_length = l_length / torch.sum(text_mask) + else: + logw_ = torch.log(w + 1e-6) * text_mask + logw = self.dp(x, text_mask, g=g) + l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(text_mask) # for averaging + + # expand prior + mean_prior = torch.matmul(attn.squeeze(1), mean_prior.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + logscale_prior = torch.matmul(attn.squeeze(1), logscale_prior.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_slice, ids_slice = rand_slice_segments(z, spec_len, self.segment_size) + audio = self.dec(z_slice, g=g) + return ( + audio, + l_length, + attn, + ids_slice, + text_mask, + spec_mask, + (z, z_p, mean_prior, logscale_prior, mean_posterior, logscale_posterior), + ) + + def infer(self, text, text_len, speakers=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, max_len=None): + x, mean_prior, logscale_prior, text_mask = self.enc_p(text, text_len) + if self.n_speakers > 1 and speakers is not None: + g = self.emb_g(speakers).unsqueeze(-1) # [b, h, 1] + else: + g = None + + if self.use_sdp: + logw = self.dp(x, text_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, text_mask, g=g) + w = torch.exp(logw) * text_mask * length_scale + w_ceil = torch.ceil(w) + audio_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + audio_mask = torch.unsqueeze(get_mask_from_lengths(audio_lengths, None), 1).to(text_mask.dtype) + attn_mask = torch.unsqueeze(text_mask, 2) * torch.unsqueeze(audio_mask, -1) + attn = generate_path(w_ceil, attn_mask) + + mean_prior = torch.matmul(attn.squeeze(1), mean_prior.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + logscale_prior = torch.matmul(attn.squeeze(1), logscale_prior.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = mean_prior + torch.randn_like(mean_prior) * torch.exp(logscale_prior) * noise_scale + z = self.flow(z_p, audio_mask, g=g, reverse=True) + audio = self.dec((z * audio_mask)[:, :, :max_len], g=g) + return audio, attn, audio_mask, (z, z_p, mean_prior, logscale_prior) + + # Can be used for emotions + def voice_conversion(self, y, y_lengths, speaker_src, speaker_tgt): + assert self.n_speakers > 1, "n_speakers have to be larger than 1." + g_src = self.emb_g(speaker_src).unsqueeze(-1) + g_tgt = self.emb_g(speaker_tgt).unsqueeze(-1) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + + +############## +# Attentions # +############## +class AttentionEncoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels ** -0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = key.size(0), key.size(1), key.size(2), query.size(2) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]) + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x diff --git a/nemo/collections/tts/torch/data.py b/nemo/collections/tts/torch/data.py index 19043995c87e..113826af8cef 100644 --- a/nemo/collections/tts/torch/data.py +++ b/nemo/collections/tts/torch/data.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import json import math import os @@ -218,6 +217,7 @@ def __init__( if isinstance(manifest_filepath, str): manifest_filepath = [manifest_filepath] self.manifest_filepath = manifest_filepath + self.lengths = [] # Needed for BucketSampling data = [] total_duration = 0 @@ -249,6 +249,8 @@ def __init__( file_info["text_tokens"] = self.text_tokenizer(file_info["normalized_text"]) data.append(file_info) + # Calculating length of spectrogram from input audio for batch sampling + self.lengths.append(os.path.getsize(item["audio_filepath"]) // (n_fft // 2)) if file_info["duration"] is None: logging.info( @@ -570,7 +572,7 @@ def __getitem__(self, index): if "text_tokens" in sample: text = torch.tensor(sample["text_tokens"]).long() - text_length = torch.tensor(len(sample["text_tokens"])).long() + text_length = torch.tensor(len(text)).long() else: tokenized = self.text_tokenizer(sample["normalized_text"]) text = torch.tensor(tokenized).long() @@ -1432,3 +1434,112 @@ def __getitem__(self, index): def __len__(self): return len(self.data) + + +class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): + """ + Maintain similar input lengths in a batch. + Length groups are specified by boundaries. + Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. + + It removes samples which are not included in the boundaries. + Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. + """ + + def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + for i in range(len(buckets) - 1, 0, -1): + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + + num_samples_per_bucket = [] + total_batch_size = self.num_replicas * self.batch_size + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)] + + # subsample + ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size + + def set_epoch(self, epoch: int) -> None: + """ + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/tutorials/nlp/Text2Sparql.ipynb b/tutorials/nlp/Text2Sparql.ipynb index b734e72c1fc6..69ccdaccadc9 100644 --- a/tutorials/nlp/Text2Sparql.ipynb +++ b/tutorials/nlp/Text2Sparql.ipynb @@ -2260,4 +2260,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} \ No newline at end of file +} diff --git a/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb b/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb old mode 100644 new mode 100755 index 596523b41c0a..f8123146f55f --- a/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb +++ b/tutorials/text_processing/Text_(Inverse)_Normalization.ipynb @@ -465,4 +465,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file