diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py index 84e9967f84..e75ab6c463 100644 --- a/TTS/encoder/models/resnet.py +++ b/TTS/encoder/models/resnet.py @@ -1,7 +1,7 @@ import torch from torch import nn -# from TTS.utils.audio import TorchSTFT +# from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.encoder.models.base_encoder import BaseEncoder diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py index 91a896f60d..e18aa0eeff 100644 --- a/TTS/encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -9,7 +9,7 @@ from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder -from TTS.utils.io import save_fsspec +from trainer.io import save_fsspec class AugmentWAV(object): diff --git a/TTS/tts/configs/fast_pitch_e2e_config.py b/TTS/tts/configs/fast_pitch_e2e_config.py new file mode 100644 index 0000000000..21bfc4c3d1 --- /dev/null +++ b/TTS/tts/configs/fast_pitch_e2e_config.py @@ -0,0 +1,178 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts_e2e import ForwardTTSE2eArgs + + +@dataclass +class FastPitchE2eConfig(BaseTTSConfig): + """Configure `ForwardTTSE2e` as FastPitchE2e model. + + Example: + + >>> from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2EConfig + >>> config = FastPitchE2EConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. 
+ + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "fast_pitch_e2e" + base_model: str = "forward_tts_e2e" + + # model specific params + # model_args: ForwardTTSE2eArgs = ForwardTTSE2eArgs(vocoder_config=HifiganConfig()) + model_args: ForwardTTSE2eArgs = ForwardTTSE2eArgs() + + # multi-speaker settings + # num_speakers: int = 0 + # speakers_file: str = None + # use_speaker_embedding: bool = False + # use_d_vector_file: bool = False + # d_vector_file: str = False + # d_vector_dim: int = 0 + spec_segment_size: int = 30 + + # optimizer + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # encoder loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = False + ssim_loss_alpha: float = 0.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # dvocoder loss params + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + mel_loss_alpha: float = 10.0 + multi_scale_stft_loss_alpha: float = 2.5 + multi_scale_stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # data loader params + return_wav: bool = True + + # overrides + r: int = 1 + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. 
It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py index a8c7f91dcd..6bfe82e670 100644 --- a/TTS/tts/configs/vits_config.py +++ b/TTS/tts/configs/vits_config.py @@ -116,12 +116,8 @@ class VitsConfig(BaseTTSConfig): dur_loss_alpha: float = 1.0 speaker_encoder_loss_alpha: float = 1.0 - # data loader params - return_wav: bool = True - compute_linear_spec: bool = True - # overrides - r: int = 1 # DO NOT CHANGE + # r: int = 1 # DO NOT CHANGE add_blank: bool = True # testing @@ -137,6 +133,7 @@ class VitsConfig(BaseTTSConfig): # multi-speaker settings # use speaker embedding layer + # TODO: keep this only in VitsArgs num_speakers: int = 0 use_speaker_embedding: bool = False speakers_file: str = None diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 6c7c9eddea..7fe3d65b3e 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -13,7 +13,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. Args: - <<<<<<< HEAD items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. @@ -23,9 +22,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): eval_split_size (float): If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). - ======= - items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`. - >>>>>>> Fix docstring """ speakers = [item["speaker_name"] for item in items] is_multi_speaker = len(set(speakers)) > 1 diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index d8f16e4efe..072da27ec3 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -9,7 +9,7 @@ from torch.utils.data import Dataset from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import compute_f0, load_wav, wav_to_mel, wav_to_spec # to prevent too many open files error as suggested here # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 @@ -37,9 +37,9 @@ def noise_augment_audio(wav): class TTSDataset(Dataset): def __init__( self, + audio_config: "Coqpit" = None, outputs_per_step: int = 1, compute_linear_spec: bool = False, - ap: AudioProcessor = None, samples: List[Dict] = None, tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, @@ -64,12 +64,12 @@ def __init__( If you need something different, you can subclass and override. Args: + audio_config (Coqpit): Audio configuration. + outputs_per_step (int): Number of time frames predicted per step. compute_linear_spec (bool): compute linear spectrogram if True. - ap (TTS.tts.utils.AudioProcessor): Audio processor object. - samples (list): List of dataset samples. tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else @@ -115,6 +115,7 @@ def __init__( verbose (bool): Print diagnostic information. Defaults to false. 
""" super().__init__() + self.audio_config = audio_config self.batch_group_size = batch_group_size self._samples = samples self.outputs_per_step = outputs_per_step @@ -126,7 +127,6 @@ def __init__( self.max_audio_len = max_audio_len self.min_text_len = min_text_len self.max_text_len = max_text_len - self.ap = ap self.phoneme_cache_path = phoneme_cache_path self.speaker_id_mapping = speaker_id_mapping self.d_vector_mapping = d_vector_mapping @@ -146,7 +146,7 @@ def __init__( if compute_f0: self.f0_dataset = F0Dataset( - self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + self.samples, self.audio_config, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers ) if self.verbose: @@ -188,7 +188,7 @@ def print_logs(self, level: int = 0) -> None: print(f"{indent}| > Number of instances : {len(self.samples)}") def load_wav(self, filename): - waveform = self.ap.load_wav(filename) + waveform = load_wav(filename) assert waveform.size > 0 return waveform @@ -408,7 +408,7 @@ def collate_fn(self, batch): else: speaker_ids = None # compute features - mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] + mel = [wav_to_mel(w).astype("float32") for w in batch["wav"]] mel_lengths = [m.shape[1] for m in mel] @@ -455,7 +455,7 @@ def collate_fn(self, batch): # compute linear spectrogram linear = None if self.compute_linear_spec: - linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] + linear = [wav_to_spec(w).astype("float32") for w in batch["wav"]] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] @@ -465,13 +465,13 @@ def collate_fn(self, batch): wav_padded = None if self.return_wav: wav_lengths = [w.shape[0] for w in batch["wav"]] - max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length + max_wav_len = max(mel_lengths_adjusted) * self.audio_config.hop_length wav_lengths = torch.LongTensor(wav_lengths) wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) for i, w in enumerate(batch["wav"]): mel_length = mel_lengths_adjusted[i] - w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") - w = w[: mel_length * self.ap.hop_length] + w = np.pad(w, (0, self.audio_config.hop_length * self.outputs_per_step), mode="edge") + w = w[: mel_length * self.audio_config.hop_length] wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) wav_padded.transpose_(1, 2) @@ -647,14 +647,14 @@ class F0Dataset: def __init__( self, samples: Union[List[List], List[Dict]], - ap: "AudioProcessor", + audio_config: "AudioConfig", verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, ): self.samples = samples - self.ap = ap + self.audio_config = audio_config self.verbose = verbose self.cache_path = cache_path self.normalize_f0 = normalize_f0 @@ -711,9 +711,9 @@ def create_pitch_file_path(wav_file, cache_path): return pitch_file @staticmethod - def _compute_and_save_pitch(ap, wav_file, pitch_file=None): - wav = ap.load_wav(wav_file) - pitch = ap.compute_f0(wav) + def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None): + wav = load_wav(wav_file) + pitch = compute_f0(x=wav, pitch_fmax=audio_config.pitch_fmax, hop_length=audio_config.hop_length, sample_rate=audio_config.sample_rate) if pitch_file: np.save(pitch_file, pitch) return pitch @@ -750,7 +750,7 @@ def compute_or_load(self, wav_file): """ pitch_file = self.create_pitch_file_path(wav_file, self.cache_path) if not 
os.path.exists(pitch_file): - pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) + pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file) else: pitch = np.load(pitch_file) return pitch.astype(np.float32) diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py index 34c586aab2..70598f91ae 100644 --- a/TTS/tts/layers/feed_forward/decoder.py +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -117,7 +117,7 @@ def __init__(self, in_channels, out_channels, params): self.postnet = nn.Conv1d(in_channels, out_channels, 1) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument - # TODO: handle multi-speaker + # TODO: maybe pass g to every block x_mask = 1 if x_mask is None else x_mask o = self.transformer_block(x) * x_mask o = self.postnet(o) * x_mask @@ -191,6 +191,9 @@ def __init__( ): super().__init__() + if c_in_channels and c_in_channels != 0: + self.cond = nn.Conv1d(c_in_channels, in_hidden_channels, 1) + if decoder_type.lower() == "relative_position_transformer": self.decoder = RelativePositionTransformerDecoder( in_channels=in_hidden_channels, @@ -225,6 +228,9 @@ def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument x_mask: [B, 1, T] g: [B, C_g, 1] """ - # TODO: implement multi-speaker - o = self.decoder(x, x_mask, g) + # multi-speaker conditioning + if hasattr(self, "cond") and self.cond is not None: + g = self.cond(g) + x = x + g + o = self.decoder(x=x, x_mask=x_mask, g=g) return o diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 9b7ecee2ba..29f3c88819 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -36,7 +36,7 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None): class FFTransformerBlock(nn.Module): - def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p, kernel_size_fft): super().__init__() self.fft_layers = nn.ModuleList( [ @@ -45,6 +45,7 @@ def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, num_heads=num_heads, hidden_channels_ffn=hidden_channels_ffn, dropout_p=dropout_p, + kernel_size_fft=kernel_size_fft, ) for _ in range(num_layers) ] @@ -71,9 +72,16 @@ def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument class FFTDurationPredictor: def __init__( - self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None + self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None, kernel_size_fft=3 ): # pylint: disable=unused-argument - self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.fft = FFTransformerBlock( + in_out_channels=in_channels, + num_heads=num_heads, + hidden_channels=hidden_channels, + num_layers=num_layers, + dropout_p=dropout_p, + kernel_size_fft=kernel_size_fft, + ) self.proj = nn.Linear(in_channels, 1) def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index e03cf0840c..b30f566be6 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -8,7 +8,8 @@ from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.ssim import ssim -from TTS.utils.audio import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.vocoder.layers.losses 
import MultiScaleSTFTLoss # pylint: disable=abstract-method @@ -739,9 +740,9 @@ def forward( pitch_output, pitch_target, input_lens, - alignment_logprob=None, - alignment_hard=None, - alignment_soft=None, + aligner_logprob=None, + aligner_hard=None, + aligner_soft=None, binary_loss_weight=None, ): loss = 0 @@ -768,12 +769,12 @@ def forward( return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0: - aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) + aligner_loss = self.aligner_loss(aligner_logprob, input_lens, decoder_output_lens) loss = loss + self.aligner_loss_alpha * aligner_loss return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss - if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: - binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss if binary_loss_weight: return_dict["loss_binary_alignment"] = ( @@ -784,3 +785,93 @@ def forward( return_dict["loss"] = loss return return_dict + + +class ForwardTTSE2eLoss(nn.Module): + def __init__(self, config): + super().__init__() + self.encoder_loss = ForwardTTSLoss(config) + self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) + # for generator losses + self.mel_loss_alpha = ( + config.mel_loss_alpha + ) # mel_loss over the encoder model output as opposed to the vocoder output + self.feat_loss_alpha = config.feat_loss_alpha + self.gen_loss_alpha = config.gen_loss_alpha + self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + def forward( + self, + decoder_output, + decoder_target, + decoder_output_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + input_lens, + waveform, + waveform_hat, + aligner_logprob=None, + aligner_hard=None, + aligner_soft=None, + binary_loss_weight=None, + feats_fake=None, + feats_real=None, + scores_fake=None, + spec_slice=None, + spec_slice_hat=None, + ): + loss_dict = self.encoder_loss( + decoder_output=decoder_output, + decoder_target=decoder_target, + decoder_output_lens=decoder_output_lens, + dur_output=dur_output, + dur_target=dur_target, + pitch_output=pitch_output, + pitch_target=pitch_target, + input_lens=input_lens, + aligner_logprob=aligner_logprob, + aligner_hard=aligner_hard, + aligner_soft=aligner_soft, + binary_loss_weight=binary_loss_weight, + ) + + # vocoder generator losses + loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.mel_loss_alpha + loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) + loss_stft_mg = loss_stft_mg * 
self.multi_scale_stft_loss_alpha + loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha + + loss_dict["vocoder_loss_mel"] = loss_mel + loss_dict["vocoder_loss_feat"] = loss_feat + loss_dict["vocoder_loss_gen"] = loss_gen + loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg + loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc + + loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_feat + loss_gen + loss_stft_sc + loss_stft_mg + return loss_dict diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index 148f283c90..4aba36df6b 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -2,7 +2,7 @@ from torch import nn from torch.nn.modules.conv import Conv1d -from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP class DiscriminatorS(torch.nn.Module): @@ -12,19 +12,19 @@ class DiscriminatorS(torch.nn.Module): use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. """ - def __init__(self, use_spectral_norm=False): + def __init__(self, use_spectral_norm=False, upsampling_rates=[4, 4, 4, 4]): super().__init__() norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) + self.convs = nn.ModuleList([norm_f(Conv1d(1, 16, 15, 1, padding=7))]) + groups = 4 + in_channels = 16 + out_channels = 64 + for rate in upsampling_rates: + self.convs.append(norm_f(Conv1d(in_channels, out_channels, 41, rate, groups=groups, padding=20))) + groups = min(groups * rate, 256) + in_channels = min(in_channels * rate, 1024) + out_channels = min(out_channels * rate, 1024) + self.convs += [norm_f(Conv1d(1024, 1024, 5, 1, padding=2))] self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) def forward(self, x): @@ -58,10 +58,10 @@ class VitsDiscriminator(nn.Module): use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. 
""" - def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False): + def __init__(self, use_spectral_norm=False, periods=[2, 3, 5, 7, 11], upsampling_rates=[4,4,4,4]): super().__init__() self.nets = nn.ModuleList() - self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm)) + self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm, upsampling_rates=upsampling_rates)) self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]) def forward(self, x, x_hat=None): diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index c71872d3ad..cad6765459 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -10,6 +10,7 @@ from torch.utils.data.sampler import WeightedRandomSampler from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from TTS.config import get_from_config_or_model_args, get_from_config_or_model_args_with_default from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.data import get_length_balancer_weights @@ -30,8 +31,8 @@ class BaseTTS(BaseTrainerModel): def __init__( self, config: Coqpit, - ap: "AudioProcessor", tokenizer: "TTSTokenizer", + ap: "AudioProcessor" = None, speaker_manager: SpeakerManager = None, language_manager: LanguageManager = None, ): @@ -108,7 +109,9 @@ def init_multispeaker(self, config: Coqpit, data: List = None): self.speaker_embedding.weight.data.normal_(0, 0.3) def get_aux_input(self, **kwargs) -> Dict: - """Prepare and return `aux_input` used by `forward()`""" + """Prepare and return `aux_input` used by `forward()` + + If not overridden, this function returns a dictionary with None values""" return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} def get_aux_input_from_test_setences(self, sentence_info): @@ -311,7 +314,7 @@ def get_data_loader( compute_f0=config.get("compute_f0", False), f0_cache_path=config.get("f0_cache_path", None), samples=samples, - ap=self.ap, + audio_config=self.config.audio, return_wav=config.return_wav if "return_wav" in config else False, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, min_text_len=config.min_text_len, @@ -323,7 +326,9 @@ def get_data_loader( use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, speaker_id_mapping=speaker_id_mapping, - d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + d_vector_mapping=d_vector_mapping + if get_from_config_or_model_args(config, "use_d_vector_file") + else None, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, language_id_mapping=language_id_mapping, @@ -428,3 +433,21 @@ def on_init_start(self, trainer): trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) print(f" > `language_ids.json` is saved to {output_path}.") print(" > `language_ids_file` is updated in the config.json.") + + +class BaseTTSE2E(BaseTTS): + def _set_model_args(self, config: Coqpit): + self.config = config + if "Config" in config.__class__.__name__: + num_chars = ( + self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + ) + self.config.model_args.num_chars = num_chars + self.config.num_chars = num_chars + self.args = config.model_args + self.args.num_chars = num_chars + elif "Args" in config.__class__.__name__: + self.args = config + self.args.num_chars = self.args.num_chars + else: + raise ValueError("config must be either a *Config or *Args") 
diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index a1273f7f7c..01e3dff57f 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -95,6 +95,9 @@ class ForwardTTSArgs(Coqpit): num_speakers (int): Number of speakers for the speaker embedding layer. Defaults to 0. + use_speaker_embedding (bool): + Whether to use a speaker embedding layer. Defaults to False. + speakers_file (str): Path to the speaker mapping file for the Speaker Manager. Defaults to None. @@ -107,8 +110,10 @@ class ForwardTTSArgs(Coqpit): d_vector_dim (int): Number of d-vector channels. Defaults to 0. - """ + d_vector_file (str): + Path to the d-vector file. Defaults to None. + """ num_chars: int = None out_channels: int = 80 hidden_channels: int = 384 @@ -126,16 +131,29 @@ class ForwardTTSArgs(Coqpit): length_scale: int = 1 encoder_type: str = "fftransformer" encoder_params: dict = field( - default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + default_factory=lambda: { + "hidden_channels_ffn": 1024, + "num_heads": 1, + "num_layers": 6, + "dropout_p": 0.1, + "kernel_size_fft": 9, + } ) decoder_type: str = "fftransformer" decoder_params: dict = field( - default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + default_factory=lambda: { + "hidden_channels_ffn": 1024, + "num_heads": 1, + "num_layers": 6, + "dropout_p": 0.1, + "kernel_size_fft": 9, + } ) detach_duration_predictor: bool = False max_duration: int = 75 num_speakers: int = 1 use_speaker_embedding: bool = False + speaker_embedding_channels: int = 256 speakers_file: str = None use_d_vector_file: bool = False d_vector_dim: int = None @@ -152,6 +170,11 @@ class ForwardTTS(BaseTTS): If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each input character as in the FastPitch model. + :: + + |-----> (optional) PitchPredictor(o_en, spk_emb) --> pitch_emb --> o_en = o_en + pitch_emb-----| -> CondConv(spk_emb) -> spk_proj + spk, text -> Encoder(text, spk)--> o_en, spk_emb -----> DurationPredictor(o_en, spk_emb)--> dur -------------------------> Expand(o_en, dur) -> PositionEncoding(o_en_expand) -> Decoder(o_en_expand_pos, spk_proj) -> mel_out + `ForwardTTS` can be configured to one of these architectures, - FastPitch @@ -165,9 +188,18 @@ class ForwardTTS(BaseTTS): Defaults to None. Examples: - >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs - >>> config = ForwardTTSArgs() - >>> model = ForwardTTS(config) + Instantiate the model directly. + + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs + >>> args = ForwardTTSE2eArgs() + >>> model = ForwardTTSE2e(args) + + Instantiate the model from config. 
+ + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e + >>> from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig + >>> config = FastPitchE2eConfig(num_chars=10) + >>> model = ForwardTTSE2e.init_from_config(config) """ # pylint: disable=dangerous-default-value @@ -213,18 +245,20 @@ def __init__( ) self.duration_predictor = DurationPredictor( - self.args.hidden_channels + self.embedded_speaker_dim, + self.args.hidden_channels, self.args.duration_predictor_hidden_channels, self.args.duration_predictor_kernel_size, self.args.duration_predictor_dropout_p, + cond_channels=self.embedded_speaker_dim, ) if self.args.use_pitch: self.pitch_predictor = DurationPredictor( - self.args.hidden_channels + self.embedded_speaker_dim, + self.args.hidden_channels, self.args.pitch_predictor_hidden_channels, self.args.pitch_predictor_kernel_size, self.args.pitch_predictor_dropout_p, + cond_channels=self.embedded_speaker_dim, ) self.pitch_emb = nn.Conv1d( 1, @@ -245,28 +279,67 @@ def init_multispeaker(self, config: Coqpit): config (Coqpit): Model configuration. """ self.embedded_speaker_dim = 0 - # init speaker manager - if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding): - raise ValueError( - " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." - ) - # set number of speakers - if self.speaker_manager is not None: + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: self.num_speakers = self.speaker_manager.num_speakers - # init d-vector embedding - if config.use_d_vector_file: - self.embedded_speaker_dim = config.d_vector_dim - if self.args.d_vector_dim != self.args.hidden_channels: - self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) - # init speaker embedding layer - if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") - self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) - nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + def _init_speaker_embedding(self): + """Init class arguments for training with a speaker embedding layer.""" + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + """Init class arguments for training with external speaker embeddings.""" + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] 
Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set auxilliary model inputs based on the model configuration.""" + sid, g, lid = None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = torch.nn.functional.normalize(aux_input["d_vectors"]).unsqueeze(-1) + if g.ndim == 2: + g = g.unsqueeze_(0) + + if "language_ids" in aux_input and aux_input["language_ids"] is not None: + lid = aux_input["language_ids"] + if lid.ndim == 0: + lid = lid.unsqueeze_(0) + + return sid, g, lid + + def get_aux_input(self, aux_input: Dict): + """Get auxilliary model inputs based on the model configuration.""" + sid, g, lid = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} @staticmethod def generate_attn(dr, x_mask, y_mask=None): - """Generate an attention mask from the durations. + """Generate an attention mask from the linear scale durations. + + Args: + dr (Tensor): Linear scale durations. + x_mask (Tensor): Mask for the input (character) sequence. + y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations + if None. Defaults to None. Shapes - dr: :math:`(B, T_{en})` @@ -283,8 +356,14 @@ def generate_attn(dr, x_mask, y_mask=None): return attn def expand_encoder_outputs(self, en, dr, x_mask, y_mask): - """Generate attention alignment map from durations and - expand encoder outputs + """Generate attention alignment map from linear scale durations and + expand encoder outputs. + + Args: + en (Tensor): Encoder outputs. + dr (Tensor): Linear scale durations. + x_mask (Tensor): Mask for the input (character) sequence. + y_mask (Tensor): Mask for the output (spectrogram) sequence. Shapes: - en: :math:`(B, D_{en}, T_{en})` @@ -316,8 +395,8 @@ def format_durations(self, o_dr_log, x_mask): 5. Round the duration values. Args: - o_dr_log: Log scale durations. - x_mask: Input text mask. + o_dr_log (Tensor): Log scale durations. + x_mask (Tensor): Input text mask. Shapes: - o_dr_log: :math:`(B, T_{de})` @@ -362,46 +441,57 @@ def _forward_encoder( # encoder pass o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) # speaker conditioning - # TODO: try different ways of conditioning - if g is not None: - o_en = o_en + g - return o_en, x_mask, g, x_emb + return x_emb, x_mask, g, o_en - def _forward_decoder( - self, - o_en: torch.FloatTensor, - dr: torch.IntTensor, - x_mask: torch.FloatTensor, - y_lengths: torch.IntTensor, - g: torch.FloatTensor, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - """Decoding forward pass. + def _expand_encoder( + self, o_en: torch.FloatTensor, y_lengths: torch.IntTensor, dr: torch.IntTensor, x_mask: torch.FloatTensor + ): + """Expand encoder outputs to match the decoder. 1. Compute the decoder output mask 2. Expand encoder output with the durations. 3. Apply position encoding. - 4. Add speaker embeddings if multi-speaker mode. - 5. Run the decoder. Args: o_en (torch.FloatTensor): Encoder output. + y_lengths (torch.IntTensor): Output sequence lengths. dr (torch.IntTensor): Ground truth durations or alignment network durations. x_mask (torch.IntTensor): Input sequence mask. - y_lengths (torch.IntTensor): Output sequence lengths. 
- g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. Returns: - Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: Decoder mask, expanded encoder outputs, + attention map """ y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) # expand o_en with durations - o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + o_en_ex, attn = self.expand_encoder_outputs(en=o_en, dr=dr, x_mask=x_mask, y_mask=y_mask) # positional encoding if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) + return y_mask, o_en_ex, attn.transpose(1, 2) + + def _forward_decoder( + self, + o_en_ex: torch.FloatTensor, + y_mask: torch.FloatTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Run the decoder network. + + Args: + o_en_ex (torch.FloatTensor): Expanded encoder output. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + # decoder pass o_de = self.decoder(o_en_ex, y_mask, g=g) - return o_de.transpose(1, 2), attn.transpose(1, 2) + return o_de.transpose(1, 2) def _forward_pitch_predictor( self, @@ -409,6 +499,7 @@ def _forward_pitch_predictor( x_mask: torch.IntTensor, pitch: torch.FloatTensor = None, dr: torch.IntTensor = None, + g: torch.FloatTensor = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: """Pitch predictor forward pass. @@ -421,6 +512,7 @@ def _forward_pitch_predictor( x_mask (torch.IntTensor): Input sequence mask. pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. Returns: Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. 
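As context for the _expand_encoder/_forward_decoder split above: expanding encoder outputs by per-character durations amounts to repeating each encoder frame dr[t] times along the time axis before decoding. A minimal conceptual sketch of that idea follows (this is not the repo's attention-map based expand_encoder_outputs; the function name is hypothetical):

import torch

# Conceptual illustration of duration-based expansion: repeat each encoder frame
# dr[t] times and zero-pad to the longest expanded sequence in the batch.
def expand_by_durations(o_en: torch.Tensor, dr: torch.Tensor) -> torch.Tensor:
    """o_en: [B, C, T_en], dr: [B, T_en] integer durations -> [B, C, max(sum(dr))]."""
    expanded = [torch.repeat_interleave(o_en[b], dr[b], dim=1) for b in range(o_en.size(0))]
    max_len = max(e.size(1) for e in expanded)
    out = o_en.new_zeros(o_en.size(0), o_en.size(1), max_len)
    for b, e in enumerate(expanded):
        out[b, :, : e.size(1)] = e
    return out

# example: two characters with durations [2, 3] expand to 5 decoder frames
o_en = torch.randn(1, 4, 2)
dr = torch.tensor([[2, 3]])
print(expand_by_durations(o_en, dr).shape)  # torch.Size([1, 4, 5])
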
@@ -431,7 +523,7 @@ def _forward_pitch_predictor( - pitch: :math:`(B, 1, T_{de})` - dr: :math:`(B, T_{en})` """ - o_pitch = self.pitch_predictor(o_en, x_mask) + o_pitch = self.pitch_predictor(o_en, x_mask, g=g) if pitch is not None: avg_pitch = average_over_durations(pitch, dr) o_pitch_emb = self.pitch_emb(avg_pitch) @@ -466,19 +558,19 @@ def _forward_aligner( - x_mask: :math:`[B, 1, T_en]` - y_mask: :math:`[B, 1, T_de]` - - o_alignment_dur: :math:`[B, T_en]` - - alignment_soft: :math:`[B, T_en, T_de]` - - alignment_logprob: :math:`[B, 1, T_de, T_en]` - - alignment_mas: :math:`[B, T_en, T_de]` + - aligner_durations: :math:`[B, T_en]` + - aligner_soft: :math:`[B, T_en, T_de]` + - aligner_logprob: :math:`[B, 1, T_de, T_en]` + - aligner_mas: :math:`[B, T_en, T_de]` """ attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) - alignment_mas = maximum_path( - alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) + aligner_mas = maximum_path( + aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() ) - o_alignment_dur = torch.sum(alignment_mas, -1).int() - alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) - return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas + aligner_durations = torch.sum(aligner_mas, -1).int() + aligner_soft = aligner_soft.squeeze(1).transpose(1, 2) + return aligner_durations, aligner_soft, aligner_logprob, aligner_mas def _set_speaker_input(self, aux_input: Dict): d_vectors = aux_input.get("d_vectors", None) @@ -523,77 +615,89 @@ def forward( - g: :math:`[B, C]` - pitch: :math:`[B, 1, T]` """ - g = self._set_speaker_input(aux_input) + spk = self._set_speaker_input(aux_input) # compute sequence masks - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() # [B, 1, T_max2] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() # [B, 1, T_max] # encoder pass - o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + x_emb, x_mask, spk_emb, o_en = self._forward_encoder( + x, x_mask, spk + ) # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max] # duration predictor pass if self.args.detach_duration_predictor: - o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=spk_emb) # [B, 1, T_max] else: - o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=spk_emb) # [B, 1, T_max] o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) # generate attn mask from predicted durations - o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + dur_predictor_attn = self.generate_attn(o_dr.squeeze(1), x_mask) # [B, T_max, T_max2'] # aligner - o_alignment_dur = None - alignment_soft = None - alignment_logprob = None - alignment_mas = None + aligner_durations = None + aligner_soft = None + aligner_logprob = None + aligner_mas = None if self.use_aligner: - o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( - x_emb, y, x_mask, y_mask + # TODO: try passing o_en instead of x_emb + aligner_durations, aligner_soft, 
aligner_logprob, aligner_mas = self._forward_aligner( + x=x_emb, y=y, x_mask=x_mask, y_mask=y_mask ) - alignment_soft = alignment_soft.transpose(1, 2) - alignment_mas = alignment_mas.transpose(1, 2) - dr = o_alignment_dur + aligner_soft = aligner_soft.transpose(1, 2) # [B, T_max, T_max2] -> [B, T_max2, T_max] + aligner_mas = aligner_mas.transpose(1, 2) # [B, T_max, T_max2] -> [B, T_max2, T_max] + dr = aligner_durations # [B, T_max] # pitch predictor pass o_pitch = None avg_pitch = None if self.args.use_pitch: - o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr) + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor( + o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=spk_emb + ) o_en = o_en + o_pitch_emb + # expand encoder outputs + y_mask, o_en_ex, attn = self._expand_encoder( + o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask + ) # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max] # decoder pass - o_de, attn = self._forward_decoder( - o_en, dr, x_mask, y_lengths, g=None - ) # TODO: maybe pass speaker embedding (g) too + o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb) # [B, T_max2, C_de] outputs = { "model_outputs": o_de, # [B, T, C] + "spk_emb": spk_emb, # [B, C] "durations_log": o_dr_log.squeeze(1), # [B, T] "durations": o_dr.squeeze(1), # [B, T] - "attn_durations": o_attn, # for visualization [B, T_en, T_de'] + "attn_durations": dur_predictor_attn, # for visualization [B, T_en, T_de'] "pitch_avg": o_pitch, "pitch_avg_gt": avg_pitch, "alignments": attn, # [B, T_de, T_en] - "alignment_soft": alignment_soft, - "alignment_mas": alignment_mas, - "o_alignment_dur": o_alignment_dur, - "alignment_logprob": alignment_logprob, + "aligner_soft": aligner_soft, + "aligner_mas": aligner_mas, + "aligner_durations": aligner_durations, + "aligner_logprob": aligner_logprob, "x_mask": x_mask, "y_mask": y_mask, + "o_en_ex": o_en_ex, } return outputs @torch.no_grad() - def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + def inference( + self, x, aux_input={"d_vectors": None, "speaker_ids": None}, skip_decoder=False + ): # pylint: disable=unused-argument """Model's inference pass. Args: x (torch.LongTensor): Input character sequence. aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + skip_decoder (bool): Whether to skip the decoder. Defaults to False. 
Shapes: - x: [B, T_max] - x_lengths: [B] - g: [B, C] """ - g = self._set_speaker_input(aux_input) + spk = self._set_speaker_input(aux_input) x_lengths = torch.tensor(x.shape[1:2]).to(x.device) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() # encoder pass - o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + _, x_mask, spk_emb, o_en = self._forward_encoder(x, x_mask, spk) # duration predictor pass o_dr_log = self.duration_predictor(o_en, x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) @@ -601,16 +705,21 @@ def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # p # pitch predictor pass o_pitch = None if self.args.use_pitch: - o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en=o_en, x_mask=x_mask, g=spk_emb) o_en = o_en + o_pitch_emb - # decoder pass - o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) + # expand encoder outputs + y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask) outputs = { - "model_outputs": o_de, "alignments": attn, "pitch": o_pitch, - "durations_log": o_dr_log, + "durations": o_dr, + "spk_emb": spk_emb, } + if skip_decoder: + outputs["o_en_ex"] = o_en_ex + else: + # decoder pass + outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb) return outputs def train_step(self, batch: dict, criterion: nn.Module): @@ -630,7 +739,7 @@ def train_step(self, batch: dict, criterion: nn.Module): ) # use aligner's output as the duration target if self.use_aligner: - durations = outputs["o_alignment_dur"] + durations = outputs["aligner_durations"] # use float32 in AMP with autocast(enabled=False): # compute loss @@ -643,9 +752,9 @@ def train_step(self, batch: dict, criterion: nn.Module): pitch_output=outputs["pitch_avg"] if self.use_pitch else None, pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, input_lens=text_lengths, - alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, - alignment_soft=outputs["alignment_soft"], - alignment_hard=outputs["alignment_mas"], + aligner_logprob=outputs["aligner_logprob"] if self.use_aligner else None, + aligner_soft=outputs["aligner_soft"], + aligner_hard=outputs["aligner_mas"], binary_loss_weight=self.binary_loss_weight, ) # compute duration error @@ -655,7 +764,7 @@ def train_step(self, batch: dict, criterion: nn.Module): return outputs, loss_dict - def _create_logs(self, batch, outputs, ap): + def create_logs(self, batch, outputs, ap): """Create common logger outputs.""" model_outputs = outputs["model_outputs"] alignments = outputs["alignments"] @@ -694,7 +803,7 @@ def _create_logs(self, batch, outputs, ap): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use - figures, audios = self._create_logs(batch, outputs, self.ap) + figures, audios = self.create_logs(batch, outputs, self.ap) logger.train_figures(steps, figures) logger.train_audios(steps, audios, self.ap.sample_rate) @@ -702,7 +811,7 @@ def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - figures, audios = self._create_logs(batch, outputs, self.ap) + figures, audios = self.create_logs(batch, outputs, self.ap) logger.eval_figures(steps, figures) 
logger.eval_audios(steps, audios, self.ap.sample_rate) diff --git a/TTS/tts/models/forward_tts_e2e.py b/TTS/tts/models/forward_tts_e2e.py new file mode 100644 index 0000000000..5bda50dedb --- /dev/null +++ b/TTS/tts/models/forward_tts_e2e.py @@ -0,0 +1,1039 @@ +import os +from dataclasses import dataclass, field +from itertools import chain +from typing import Dict, List, Tuple, Union + +import numpy as np +import pyworld as pw +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.utils.data import DataLoader +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample +from TTS.tts.layers.losses import ForwardTTSE2eLoss, VitsDiscriminatorLoss +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.models.base_tts import BaseTTSE2E +from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs +from TTS.tts.models.vits import load_audio, wav_to_mel +from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram +from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 +from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy +from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + + +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) + if cuda: + return aux_id.cuda() + return aux_id + + +def embedding_to_torch(d_vector, cuda=False): + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0) + if cuda: + return d_vector.cuda() + return d_vector + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +############################## +# DATASET +############################## + + +class ForwardTTSE2eF0Dataset(F0Dataset): + """Override F0Dataset to avoid the AudioProcessor.""" + + def __init__( + self, + audio_config: "AudioConfig", + samples: Union[List[List], List[Dict]], + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + super().__init__( + samples=samples, + audio_config=audio_config, + verbose=verbose, + cache_path=cache_path, + precompute_num_workers=precompute_num_workers, + normalize_f0=normalize_f0, + ) + + @staticmethod + def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None): + wav, _ = load_audio(wav_file) + f0 = compute_f0( + x=wav.numpy()[0], + sample_rate=audio_config.sample_rate, + hop_length=audio_config.hop_length, + pitch_fmax=audio_config.pitch_fmax, + ) + # skip the last F0 value to align with the spectrogram + if wav.shape[1] % audio_config.hop_length != 0: + f0 = f0[:-1] + if pitch_file: + np.save(pitch_file, f0) + return f0 + + def compute_or_load(self, wav_file): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = 
self.create_pitch_file_path(wav_file, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch( + audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file + ) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + +class ForwardTTSE2eDataset(TTSDataset): + def __init__(self, *args, **kwargs): + # don't init the default F0Dataset in TTSDataset + compute_f0 = kwargs.pop("compute_f0", False) + kwargs["compute_f0"] = False + + super().__init__(*args, **kwargs) + + self.compute_f0 = compute_f0 + self.pad_id = self.tokenizer.characters.pad_id + self.audio_config = kwargs["audio_config"] + + if self.compute_f0: + self.f0_dataset = ForwardTTSE2eF0Dataset( + audio_config=self.audio_config, + samples=self.samples, + cache_path=kwargs["f0_cache_path"], + precompute_num_workers=kwargs["precompute_num_workers"], + ) + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "pitch": f0, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - pitch :math:`[B, T]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + pitch_lens = [p.shape[0] for p in batch["pitch"]] + pitch_lens = torch.LongTensor(pitch_lens) + pitch_lens_max = torch.max(pitch_lens) + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max) + + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + pitch_padded = pitch_padded.zero_() + self.pad_id + + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + 
wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + pitch = batch["pitch"][i] + pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) + + return { + "text_input": token_padded, + "text_lengths": token_lens, + "text_rel_lens": token_rel_lens, + "pitch": pitch_padded, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + } + + +############################## +# CONFIG DEFINITIONS +############################## + + +@dataclass +class ForwardTTSE2eAudio(Coqpit): + sample_rate: int = 22050 + hop_length: int = 256 + win_length: int = 1024 + fft_size: int = 1024 + mel_fmin: float = 0.0 + mel_fmax: float = 8000 + num_mels: int = 80 + pitch_fmax: float = 640.0 + + +@dataclass +class ForwardTTSE2eArgs(ForwardTTSArgs): + # vocoder_config: BaseGANVocoderConfig = None + num_chars: int = 100 + encoder_out_channels: int = 80 + spec_segment_size: int = 80 + # duration predictor + detach_duration_predictor: bool = True + duration_predictor_dropout_p: float = 0.1 + # pitch predictor + pitch_predictor_dropout_p: float = 0.1 + # discriminator + init_discriminator: bool = True + use_spectral_norm_discriminator: bool = False + # model parameters + detach_vocoder_input: bool = False + hidden_channels: int = 256 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: { + "hidden_channels_ffn": 1024, + "num_heads": 2, + "num_layers": 4, + "dropout_p": 0.1, + "kernel_size_fft": 9, + } + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: { + "hidden_channels_ffn": 1024, + "num_heads": 2, + "num_layers": 4, + "dropout_p": 0.1, + "kernel_size_fft": 9, + } + ) + # generator + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + # discriminator + upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) + periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + # multi-speaker params + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: str = None + speaker_embedding_channels: int = 384 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + + +############################## +# MODEL DEFINITION +############################## + + +class ForwardTTSE2e(BaseTTSE2E): + """ + Model training:: + text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg + spec --------^ + + Examples: + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig + >>> config = ForwardTTSE2eConfig() + >>> model = ForwardTTSE2e(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config=config, tokenizer=tokenizer, speaker_manager=speaker_manager) + self._set_model_args(config) + + self.init_multispeaker(config) + + 
self.encoder_model = ForwardTTS(config=self.args, ap=None, tokenizer=tokenizer, speaker_manager=speaker_manager) + # self.vocoder_model = GAN(config=self.args.vocoder_config, ap=ap) + self.waveform_decoder = HifiganGenerator( + self.args.hidden_channels, + 1, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, + inference_padding=0, + cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + # use Vits Discriminator for limiting VRAM use + if self.args.init_discriminator: + self.disc = VitsDiscriminator( + use_spectral_norm=self.args.use_spectral_norm_discriminator, + periods=self.args.periods_discriminator, + upsampling_rates=self.args.upsampling_rates_discriminator, + ) + + # def check_model_args(self): + # upsample_rate = torch.prod(torch.as_tensor(self.args.upsample_rates_decoder)).item() + # if s + # assert ( + # upsample_rate == self.config.audio.hop_length + # ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + def get_aux_input(self, *args, **kwargs) -> Dict: + return self.encoder_model.get_aux_input(*args, **kwargs) + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + spec_lengths: torch.LongTensor, + spec: torch.FloatTensor, + waveform: torch.FloatTensor, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + spec_lengths (torch.LongTensor): Spectrogram sequnce lengths. Defaults to None. + spec (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + waveform (torch.FloatTensor): Waveform. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. 
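# The commented-out check_model_args above encodes an invariant worth keeping in mind: the product
# of the decoder upsample rates must equal audio.hop_length so one decoder input frame expands to
# exactly one hop of audio. A small sanity check using the ForwardTTSE2eArgs / ForwardTTSE2eAudio
# defaults from this file:
import math

upsample_rates_decoder = [8, 8, 2, 2]   # ForwardTTSE2eArgs default
hop_length = 256                        # ForwardTTSE2eAudio default
assert math.prod(upsample_rates_decoder) == hop_length, (
    f" [!] Product of upsample rates must equal the hop length - {math.prod(upsample_rates_decoder)} vs {hop_length}"
)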
Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - spec_lengths: :math:`[B]` + - spec: :math:`[B, T_max2]` + - waveform: :math:`[B, C, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + encoder_outputs = self.encoder_model( + x=x, x_lengths=x_lengths, y_lengths=spec_lengths, y=spec, dr=dr, pitch=pitch, aux_input=aux_input + ) + o_en_ex = encoder_outputs["o_en_ex"].transpose(1, 2) # [B, C_en, T_max2] -> [B, T_max2, C_en] + o_en_ex_slices, slice_ids = rand_segments( + x=o_en_ex.transpose(1, 2), + x_lengths=spec_lengths, + segment_size=self.args.spec_segment_size, + let_short_samples=True, + pad_short=True, + ) + + vocoder_output = self.waveform_decoder( + x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices, + g=encoder_outputs["spk_emb"], + ) + wav_seg = segment( + waveform, + slice_ids * self.config.audio.hop_length, + self.args.spec_segment_size * self.config.audio.hop_length, + pad_short=True, + ) + model_outputs = {**encoder_outputs} + model_outputs["encoder_outputs"] = encoder_outputs["model_outputs"] + model_outputs["model_outputs"] = vocoder_output + model_outputs["waveform_seg"] = wav_seg + model_outputs["slice_ids"] = slice_ids + return model_outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True) + o_en_ex = encoder_outputs["o_en_ex"] + vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["spk_emb"]) + model_outputs = {**encoder_outputs} + model_outputs["model_outputs"] = vocoder_output + return model_outputs + + @torch.no_grad() + def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=False) + model_outputs = {**encoder_outputs} + return model_outputs + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + if optimizer_idx == 0: + tokens = batch["text_input"] + token_lenghts = batch["text_lengths"] + spec = batch["mel_input"] + spec_lens = batch["mel_lengths"] + waveform = batch["waveform"] # [B, T, C] -> [B, C, T] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + language_ids = batch["language_ids"] + + # generator pass + outputs = self.forward( + x=tokens, + x_lengths=token_lenghts, + spec_lengths=spec_lens, + spec=spec, + waveform=waveform, + pitch=pitch, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + # compute scores and features + scores_d_fake, _, scores_d_real, _ = self.disc(outputs["model_outputs"].detach(), outputs["waveform_seg"]) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_fake=scores_d_fake, + scores_disc_real=scores_d_real, + ) + return outputs, loss_dict + + if optimizer_idx == 1: + mel = batch["mel_input"].transpose(1, 2) + + # compute melspec segment + with autocast(enabled=False): + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True + ) + + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + 
n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + fmin=self.config.audio.mel_fmin, + fmax=self.config.audio.mel_fmax, + center=False, + ) + + # compute discriminator scores and features + scores_d_fake, feats_d_fake, _, feats_d_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + decoder_output=self.model_outputs_cache["encoder_outputs"], + decoder_target=batch["mel_input"], + decoder_output_lens=batch["mel_lengths"], + dur_output=self.model_outputs_cache["durations_log"], + dur_target=self.model_outputs_cache["aligner_durations"], + pitch_output=self.model_outputs_cache["pitch_avg"] if self.args.use_pitch else None, + pitch_target=self.model_outputs_cache["pitch_avg_gt"] if self.args.use_pitch else None, + input_lens=batch["text_lengths"], + waveform=self.model_outputs_cache["waveform_seg"], + waveform_hat=self.model_outputs_cache["model_outputs"], + aligner_logprob=self.model_outputs_cache["aligner_logprob"], + aligner_hard=self.model_outputs_cache["aligner_mas"], + aligner_soft=self.model_outputs_cache["aligner_soft"], + binary_loss_weight=self.encoder_model.binary_loss_weight, + feats_fake=feats_d_fake, + feats_real=feats_d_real, + scores_fake=scores_d_fake, + spec_slice=mel_slice, + spec_slice_hat=mel_slice_hat, + ) + + # compute duration error for logging + durations_pred = self.model_outputs_cache["durations"] + durations_target = self.model_outputs_cache["aligner_durations"] + duration_error = torch.abs(durations_target - durations_pred).sum() / batch["text_lengths"].sum() + loss_dict["duration_error"] = duration_error + + return self.model_outputs_cache, loss_dict + + raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def _log(self, batch, outputs, name_prefix="train"): + figures, audios = {}, {} + + # encoder outputs + model_outputs = outputs[1]["encoder_outputs"] + alignments = outputs[1]["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, None, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, None, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch_avg = abs(outputs[1]["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs[1]["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs[1]: + alignments_hat = outputs[1]["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + encoder_audio = mel_to_wav_numpy( + mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.__mel_basis, **self.config.audio + ) + audios[f"{name_prefix}/encoder_audio"] = encoder_audio + + # vocoder outputs + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + + vocoder_figures = plot_results(y_hat=y_hat, y=y, audio_config=self.config.audio, name_prefix=name_prefix) + figures.update(vocoder_figures) + + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios[f"{name_prefix}/real_audio"] = sample_voice + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use, unused-argument + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previous training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. 
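# In forward() above, each randomly picked spectrogram segment is paired with its waveform window
# by scaling the slice start indices with hop_length. A self-contained sketch of that index
# arithmetic (illustrative values only; not the rand_segments/segment helpers themselves):
import torch

hop_length = 256
spec_segment_size = 80                    # frames passed to the waveform decoder
slice_ids = torch.tensor([12, 40])        # example start frames picked per batch item

wav_starts = slice_ids * hop_length       # start sample of each waveform window
wav_segment_len = spec_segment_size * hop_length
assert wav_segment_len == 20480           # 80 frames * 256 samples per frame
assert wav_starts.tolist() == [3072, 10240]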
+ """ + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.config.audio.sample_rate) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.config.audio.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None # pylint: disable=unused-variable + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None # pylint: disable=unused-variable + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_d_vector() + else: + d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_speaker_id() + else: + speaker_id = self.speaker_manager.speaker_ids[speaker_name] + + # get language id + # if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + # language_id = self.language_manager.language_id_mapping[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": None, + "language_name": None, + } + + def synthesize(self, text: str, speaker_id, language_id, d_vector): + # TODO: add language_id + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=language_id), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + # if language_id is not None: + # language_id = id_to_torch(language_id, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}) + + # collect outputs + wav = outputs["model_outputs"][0].data.cpu().numpy() + alignments = outputs["alignments"] + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + def synthesize_with_gl(self, text: str, speaker_id, language_id, d_vector): + # TODO: add language_id + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=language_id), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id 
= id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + # if language_id is not None: + # language_id = id_to_torch(language_id, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}) + + # collect outputs + wav = mel_to_wav_numpy( + mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio + ) + alignments = outputs["alignments"] + return_dict = { + "wav": wav[None, :], + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + outputs = self.synthesize( + aux_inputs["text"], + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + language_id=aux_inputs["language_id"], + ) + outputs_gl = self.synthesize_with_gl( + aux_inputs["text"], + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + language_id=aux_inputs["language_id"], + ) + test_audios["{}-audio".format(idx)] = outputs["wav"].T + test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.d_vectors + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if ( + self.language_manager is not None + and self.language_manager.language_id_mapping + and self.args.use_language_embedding + ): + language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = 
speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + # compute spectrograms + batch["mel_input"] = wav_to_mel( + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + center=False, + ) + + # TODO: Align pitch properly + # assert ( + # batch["pitch"].shape[2] == batch["mel_input"].shape[2] + # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}" + batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] + batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() + + # zero the padding frames + batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) + batch["mel_input"] = batch["mel_input"].transpose(1, 2) + return batch + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = ForwardTTSE2eDataset( + samples=samples, + audio_config=self.config.audio, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + compute_f0=config.compute_f0, + f0_cache_path=config.f0_cache_path, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_criterion(self): + return [VitsDiscriminatorLoss(self.config), ForwardTTSE2eLoss(self.config)] + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + Returns: + List: optimizers. + """ + optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer0, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. 
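# get_optimizer above builds one optimizer for the discriminator and one for everything else by
# filtering named parameters on the "disc." prefix. A minimal sketch of that split with plain
# torch.optim.AdamW on a toy module (the trainer's get_optimizer helper is bypassed here):
import torch
from torch import nn

class TinyGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(4, 4)
        self.disc = nn.Linear(4, 1)

model = TinyGAN()
disc_params = [p for k, p in model.named_parameters() if k.startswith("disc.")]
gen_params = [p for k, p in model.named_parameters() if not k.startswith("disc.")]
optimizer_disc = torch.optim.AdamW(disc_params, lr=0.0002, betas=(0.8, 0.99), weight_decay=0.01)
optimizer_gen = torch.optim.AdamW(gen_params, lr=0.0002, betas=(0.8, 0.99), weight_decay=0.01)
# returned as [optimizer_disc, optimizer_gen], matching optimizer_idx 0 and 1 in train_step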
+ + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def on_train_step_start(self, trainer): + """Schedule binary loss weight.""" + self.encoder_model.config.binary_loss_warmup_epochs = self.config.binary_loss_warmup_epochs + self.encoder_model.on_train_step_start(trainer) + + def on_init_start(self, trainer: "Trainer"): + self.__mel_basis = build_mel_basis( + sample_rate=self.config.audio.sample_rate, + fft_size=self.config.audio.fft_size, + num_mels=self.config.audio.num_mels, + mel_fmax=self.config.audio.mel_fmax, + mel_fmin=self.config.audio.mel_fmin, + ) + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False): + """Initiate model from config + + Args: + config (ForwardTTSE2eConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio.processor import AudioProcessor + + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + # language_manager = LanguageManager.init_from_config(config) + return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager) + + def load_checkpoint(self, config, checkpoint_path, eval=False): + """Load model from a checkpoint created by the 👟""" + # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_state_dict(self): + """Custom state dict of the model with all the necessary components for inference.""" + save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict} + + if hasattr(self, "emb_g"): + save_state["speaker_ids"] = self.speaker_manager.speaker_ids + + if self.args.use_d_vector_file: + # TODO: implement saving of d_vectors + ... 
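# get_state_dict/save and load_checkpoint above round-trip the model through a plain torch
# checkpoint dict. A toy sketch of the same pattern; the keys mirror the "config"/"model" layout
# used here, but nothing below is the exact on-disk format of a trained checkpoint:
import torch
from torch import nn

model = nn.Linear(8, 8)
torch.save({"config": {"hidden": 8}, "model": model.state_dict()}, "/tmp/forward_tts_e2e_toy.pth")

state = torch.load("/tmp/forward_tts_e2e_toy.pth", map_location=torch.device("cpu"))
model.load_state_dict(state["model"])
model.eval()
assert not model.training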
+ return save_state + + def save(self, config, checkpoint_path): + """Save model to a file.""" + save_state = self.get_state_dict(config, checkpoint_path) + torch.save(save_state, checkpoint_path) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 2c1c2bc67b..a476e8709a 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -29,6 +29,7 @@ from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment +from TTS.utils.generic_utils import count_parameters from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -125,6 +126,7 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False): pad_mode="reflect", normalized=False, onesided=True, + return_complex=True, ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -153,10 +155,10 @@ def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): """ Args Shapes: - - y : :math:`[B, 1, T]` + - y : :math:`[B, 1, T_y]` Return Shapes: - - spec : :math:`[B,C,T]` + - spec : :math:`[B,C,T_spec]` """ y = y.squeeze(1) @@ -521,7 +523,7 @@ class VitsArgs(Coqpit): inference_noise_scale_dp: float = 1.0 max_inference_len: int = None init_discriminator: bool = True - use_spectral_norm_disriminator: bool = False + use_spectral_norm_discriminator: bool = False use_speaker_embedding: bool = False num_speakers: int = 0 speakers_file: str = None @@ -857,21 +859,21 @@ def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb) logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp2 + logp3 + logp1 + logp4 - attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + mas_attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] # duration predictor - attn_durations = attn.sum(3) + mas_attn_durations = mas_attn.sum(3) if self.args.use_sdp: loss_duration = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, x_mask, - attn_durations, + mas_attn_durations, g=g.detach() if self.args.detach_dp_input and g is not None else g, lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) loss_duration = loss_duration / torch.sum(x_mask) else: - attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + attn_log_durations = torch.log(mas_attn_durations + 1e-6) * x_mask log_durations = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, x_mask, @@ -880,7 +882,7 @@ def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb) ) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) outputs["loss_duration"] = loss_duration - return outputs, attn + return outputs, mas_attn def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): spec_segment_size = self.spec_segment_size @@ -965,11 +967,11 @@ def forward( # pylint: disable=dangerous-default-value z_p = self.flow(z, y_mask, g=g) # duration predictor - outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + outputs, mas_attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) # expand 
prior - m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) - logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + m_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, logs_p]) # select a random feature segment for the waveform decoder z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) @@ -1005,7 +1007,7 @@ def forward( # pylint: disable=dangerous-default-value outputs.update( { "model_outputs": o, - "alignments": attn.squeeze(1), + "alignments": mas_attn.squeeze(1), "m_p": m_p, "logs_p": logs_p, "z": z, @@ -1269,7 +1271,8 @@ def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> T raise ValueError(" [!] Unexpected `optimizer_idx`.") - def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use + @staticmethod + def _log(ap, outputs, name_prefix="train"): y_hat = outputs[1]["model_outputs"] y = outputs[1]["waveform_seg"] figures = plot_results(y_hat, y, ap, name_prefix) @@ -1302,7 +1305,7 @@ def train_log( Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - figures, audios = self._log(self.ap, batch, outputs, "train") + figures, audios = self._log(self.ap, outputs, "train") logger.train_figures(steps, figures) logger.train_audios(steps, audios, self.ap.sample_rate) @@ -1311,7 +1314,7 @@ def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - figures, audios = self._log(self.ap, batch, outputs, "eval") + figures, audios = self._log(self.ap, outputs, "eval") logger.eval_figures(steps, figures) logger.eval_audios(steps, audios, self.ap.sample_rate) @@ -1542,9 +1545,7 @@ def get_optimizer(self) -> List: Returns: List: optimizers. """ - # select generator parameters optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) - gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) optimizer1 = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters diff --git a/TTS/utils/audio/__init__.py b/TTS/utils/audio/__init__.py new file mode 100644 index 0000000000..f18f221999 --- /dev/null +++ b/TTS/utils/audio/__init__.py @@ -0,0 +1 @@ +from TTS.utils.audio.processor import AudioProcessor diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py new file mode 100644 index 0000000000..2633b83c3f --- /dev/null +++ b/TTS/utils/audio/numpy_transforms.py @@ -0,0 +1,446 @@ +from typing import Callable, Tuple + +import librosa +import numpy as np +import pyworld as pw +import scipy +import soundfile as sf + +# from TTS.tts.utils.helpers import StandardScaler + + +def build_mel_basis( + *, + sample_rate: int = None, + fft_size: int = None, + num_mels: int = None, + mel_fmax: int = None, + mel_fmin: int = None, + **kwargs, +) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. 
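# build_mel_basis above is a thin wrapper over librosa.filters.mel; the result has shape
# (num_mels, fft_size // 2 + 1) and is applied to magnitude spectrograms with a dot product.
# A usage sketch with the ForwardTTSE2eAudio defaults:
import librosa
import numpy as np

mel_basis = librosa.filters.mel(sr=22050, n_fft=1024, n_mels=80, fmin=0.0, fmax=8000)
assert mel_basis.shape == (80, 1024 // 2 + 1)      # (num_mels, fft_size // 2 + 1)

spec = np.abs(np.random.randn(513, 10))            # linear magnitude spectrogram [C, T]
mel = np.dot(mel_basis, spec)                      # melspectrogram [num_mels, T]
assert mel.shape == (80, 10)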
+ """ + if mel_fmax is not None: + assert mel_fmax <= sample_rate // 2 + assert mel_fmax - mel_fmin > 0 + return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax) + + +def millisec_to_length( + *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs +) -> Tuple[int, int]: + """Compute hop and window length from milliseconds. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = frame_length_ms / frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + hop_length = int(frame_shift_ms / 1000.0 * sample_rate) + win_length = int(hop_length * factor) + return hop_length, win_length + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) + + +def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Decibels spectrogram. + """ + return gain * _log(np.maximum(1e-5, x), base) + + +# pylint: disable=no-self-use +def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / gain, base) + + +def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -coef], [1], x) + + +def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray: + """Reverse pre-emphasis.""" + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -coef], x) + + +def spec_to_mel(*, spectrogram: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Project a full scale spectrogram to a melspectrogram. + + Args: + spectrogram (np.ndarray): Full scale spectrogram. + + Returns: + np.ndarray: Melspectrogram + """ + return np.dot(mel_basis, spectrogram) + + +def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + inv_mel_basis = np.linalg.pinv(mel_basis) + return np.maximum(1e-10, np.dot(inv_mel_basis, mel)) + + +def wav_to_spec(*, y: np.ndarray = None, **kwargs) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + y (np.ndarray): Waveform. + + Returns: + np.ndarray: Spectrogram. 
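# millisec_to_length above derives the STFT hop and window sizes from millisecond settings;
# frame_length_ms must be an integer multiple of frame_shift_ms. A worked example with round
# numbers (values chosen for clarity, not taken from a shipped config):
sample_rate = 16000
frame_shift_ms = 12.5
frame_length_ms = 50

factor = frame_length_ms / frame_shift_ms                  # 4.0, must be an integer
assert factor.is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
hop_length = int(frame_shift_ms / 1000.0 * sample_rate)    # 200 samples
win_length = int(hop_length * factor)                      # 800 samples
assert (hop_length, win_length) == (200, 800)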
+ """ + D = stft(y, **kwargs) + S = np.abs(D) + return S.astype(np.float32) + + +def wav_to_mel(*, y: np.ndarray = None, **kwargs) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + D = stft(y=y, **kwargs) + S = spec_to_mel(spec=np.abs(D), **kwargs) + return S.astype(np.float32) + + +def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = spec.copy() + return griffin_lim(spec=S**power, **kwargs) + + +def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + S = mel.copy() + S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear + return griffin_lim(spec=S**power, **kwargs) + + +def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + linear_spec (np.ndarray): Normalized full scale linear spectrogram. + + Returns: + np.ndarray: Normalized melspectrogram. + """ + return np.dot(mel_basis, spec) + + +### STFT and ISTFT ### +def stft( + *, + y: np.ndarray = None, + fft_size: int = None, + hop_length: int = None, + win_length: int = None, + pad_mode: str = "reflect", + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa STFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.stft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.stft( + y=y, + n_fft=fft_size, + hop_length=hop_length, + win_length=win_length, + pad_mode=pad_mode, + window="hann", + center=True, + ) + + +def istft( + *, + y: np.ndarray = None, + fft_size: int = None, + hop_length: int = None, + win_length: int = None, + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa iSTFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.istft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window) + + +def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray: + angles = np.exp(2j * np.pi * np.random.rand(*spec.shape)) + S_complex = np.abs(spec).astype(np.complex) + y = istft(y=S_complex * angles, **kwargs) + if not np.isfinite(y).all(): + print(" [!] Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(num_iter): + angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) + y = istft(y=S_complex * angles, **kwargs) + return y + + +def compute_stft_paddings( + *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs +) -> Tuple[int, int]: + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0] + if not pad_two_sides: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + +def compute_f0( + *, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs +) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. 
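# griffin_lim above runs the standard phase-recovery loop. The same loop written directly against
# librosa.stft/librosa.istft on a toy tone (60 iterations, matching the default num_iter); this is
# an illustration of the algorithm, not the module's function:
import librosa
import numpy as np

y = librosa.tone(220, sr=22050, length=22050)                       # 1 s, 220 Hz tone
S = np.abs(librosa.stft(y, n_fft=1024, hop_length=256, win_length=1024))

angles = np.exp(2j * np.pi * np.random.rand(*S.shape))              # random initial phase
y_hat = librosa.istft(S * angles, hop_length=256, win_length=1024)
for _ in range(60):
    angles = np.exp(1j * np.angle(librosa.stft(y_hat, n_fft=1024, hop_length=256, win_length=1024)))
    y_hat = librosa.istft(S * angles, hop_length=256, win_length=1024)
assert np.isfinite(y_hat).all()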
+ + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio.processor import AudioProcessor >>> conf = BaseAudioConfig(pitch_fmax=8000) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] + >>> pitch = ap.compute_f0(wav) + """ + assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`." + + f0, t = pw.dio( + x.astype(np.double), + fs=sample_rate, + f0_ceil=pitch_fmax, + frame_period=1000 * hop_length / sample_rate, + ) + f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate) + return f0 + + +### Audio Processing ### +def find_endpoint( + *, + wav: np.ndarray = None, + trim_db: float = None, + sample_rate: int = None, + min_silence_sec=0.8, + gain: float = None, + base: int = 10, + **kwargs, +) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + + Returns: + int: Last point without silence. + """ + window_length = int(sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = db_to_amp(x=-trim_db, gain=gain, base=base) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + +def trim_silence( + *, + wav: np.ndarray = None, + sample_rate: int = None, + trim_db: float = None, + win_length: int = None, + hop_length: int = None, + **kwargs, +) -> np.ndarray: + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0] + + +def sound_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + coef (float): Coefficient to rescale the maximum value. Defaults to 0.95. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * coef + + +def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray: + r = 10 ** (db_level / 20) + a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) + return wav * a + + +def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + db_level (float): Target dB level in RMS. Defaults to -27.0. + + Returns: + np.ndarray: RMS normalized waveform. + """ + if db_level is None: + db_level = db_level + assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" + wav = rms_norm(wav=x, db_level=db_level) + return wav + + +def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False. 
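# compute_f0 above uses WORLD's dio + stonemask with a frame_period matched to the mel hop, so it
# yields roughly one f0 value per spectrogram frame. A self-contained sketch on a synthetic tone
# (same pyworld calls, toy input):
import numpy as np
import pyworld as pw

sample_rate, hop_length, pitch_fmax = 22050, 256, 640.0
t = np.arange(0, 1.0, 1.0 / sample_rate)
x = (0.5 * np.sin(2 * np.pi * 220.0 * t)).astype(np.double)         # 1 s, 220 Hz tone

f0, time_axis = pw.dio(x, fs=sample_rate, f0_ceil=pitch_fmax, frame_period=1000 * hop_length / sample_rate)
f0 = pw.stonemask(x, f0, time_axis, sample_rate)
assert f0.shape[0] >= len(x) // hop_length                           # ~ one value per hop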
+ + Returns: + np.ndarray: Loaded waveform. + """ + if resample: + # loading with resampling. It is significantly slower. + x, sr = librosa.load(filename, sr=sample_rate) + elif sr is None: + # SF is faster than librosa for loading files + x, sr = sf.read(filename) + assert sample_rate == sr, "%s vs %s" % (sample_rate, sr) + else: + x, sr = librosa.load(filename, sr=sr) + return x + + +def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None: + """Save float waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform with float values in range [-1, 1] to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16)) + + +def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray: + mu = 2**mulaw_qc - 1 + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + +def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray: + """Recovers waveform from quantized values.""" + mu = 2**mulaw_qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + +def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray: + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) + + +def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + quantize_bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. + """ + return (x + 1.0) * (2**quantize_bits - 1) / 2 + + +def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray: + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2**quantize_bits - 1) - 1 + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) diff --git a/TTS/utils/audio.py b/TTS/utils/audio/processor.py similarity index 64% rename from TTS/utils/audio.py rename to TTS/utils/audio/processor.py index fc9d194201..7186a7cbfd 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio/processor.py @@ -1,177 +1,9 @@ -from typing import Dict, Tuple - import librosa import numpy as np import pyworld as pw import scipy.io.wavfile import scipy.signal import soundfile as sf -import torch -from torch import nn - -from TTS.tts.utils.helpers import StandardScaler - - -class TorchSTFT(nn.Module): # pylint: disable=abstract-method - """Some of the audio processing funtions using Torch for faster batch processing. - - TODO: Merge this with audio.py - - Args: - - n_fft (int): - FFT window size for STFT. - - hop_length (int): - number of frames between STFT columns. - - win_length (int, optional): - STFT window length. - - pad_wav (bool, optional): - If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. - - window (str, optional): - The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" - - sample_rate (int, optional): - target audio sampling rate. Defaults to None. - - mel_fmin (int, optional): - minimum filter frequency for computing melspectrograms. Defaults to None. 
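# mulaw_encode / mulaw_decode above are the classic companding pair: encode returns integer codes
# in [0, 2**mulaw_qc - 1], while decode expects the companded signal back in [-1, 1]. The rescale
# step below is therefore an assumption about how callers are expected to use these helpers; it is
# not shown in this file. Assumes the new numpy_transforms module from this branch is importable.
import numpy as np
from TTS.utils.audio.numpy_transforms import mulaw_decode, mulaw_encode

mulaw_qc = 10
mu = 2 ** mulaw_qc - 1
x = np.linspace(-0.9, 0.9, 7)

q = mulaw_encode(wav=x, mulaw_qc=mulaw_qc)                      # integer codes in [0, mu]
x_hat = mulaw_decode(wav=2 * q / mu - 1, mulaw_qc=mulaw_qc)     # map codes back to [-1, 1] first
assert np.max(np.abs(x - x_hat)) < 1e-2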
- - mel_fmax (int, optional): - maximum filter frequency for computing melspectrograms. Defaults to None. - - n_mels (int, optional): - number of melspectrogram dimensions. Defaults to None. - - use_mel (bool, optional): - If True compute the melspectrograms otherwise. Defaults to False. - - do_amp_to_db_linear (bool, optional): - enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. - - spec_gain (float, optional): - gain applied when converting amplitude to DB. Defaults to 1.0. - - power (float, optional): - Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. - - use_htk (bool, optional): - Use HTK formula in mel filter instead of Slaney. - - mel_norm (None, 'slaney', or number, optional): - If 'slaney', divide the triangular mel weights by the width of the mel band - (area normalization). - - If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. - See `librosa.util.normalize` for a full description of supported norm values - (including `+-np.inf`). - - Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". - """ - - def __init__( - self, - n_fft, - hop_length, - win_length, - pad_wav=False, - window="hann_window", - sample_rate=None, - mel_fmin=0, - mel_fmax=None, - n_mels=80, - use_mel=False, - do_amp_to_db=False, - spec_gain=1.0, - power=None, - use_htk=False, - mel_norm="slaney", - ): - super().__init__() - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.pad_wav = pad_wav - self.sample_rate = sample_rate - self.mel_fmin = mel_fmin - self.mel_fmax = mel_fmax - self.n_mels = n_mels - self.use_mel = use_mel - self.do_amp_to_db = do_amp_to_db - self.spec_gain = spec_gain - self.power = power - self.use_htk = use_htk - self.mel_norm = mel_norm - self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) - self.mel_basis = None - if use_mel: - self._build_mel_basis() - - def __call__(self, x): - """Compute spectrogram frames by torch based stft. - - Args: - x (Tensor): input waveform - - Returns: - Tensor: spectrogram frames. 
- - Shapes: - x: [B x T] or [:math:`[B, 1, T]`] - """ - if x.ndim == 2: - x = x.unsqueeze(1) - if self.pad_wav: - padding = int((self.n_fft - self.hop_length) / 2) - x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") - # B x D x T x 2 - o = torch.stft( - x.squeeze(1), - self.n_fft, - self.hop_length, - self.win_length, - self.window, - center=True, - pad_mode="reflect", # compatible with audio.py - normalized=False, - onesided=True, - return_complex=False, - ) - M = o[:, :, :, 0] - P = o[:, :, :, 1] - S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) - - if self.power is not None: - S = S**self.power - - if self.use_mel: - S = torch.matmul(self.mel_basis.to(x), S) - if self.do_amp_to_db: - S = self._amp_to_db(S, spec_gain=self.spec_gain) - return S - - def _build_mel_basis(self): - mel_basis = librosa.filters.mel( - self.sample_rate, - self.n_fft, - n_mels=self.n_mels, - fmin=self.mel_fmin, - fmax=self.mel_fmax, - htk=self.use_htk, - norm=self.mel_norm, - ) - self.mel_basis = torch.from_numpy(mel_basis).float() - - @staticmethod - def _amp_to_db(x, spec_gain=1.0): - return torch.log(torch.clamp(x, min=1e-5) * spec_gain) - - @staticmethod - def _db_to_amp(x, spec_gain=1.0): - return torch.exp(x) / spec_gain # pylint: disable=too-many-public-methods @@ -398,158 +230,6 @@ def init_from_config(config: "Coqpit", verbose=True): return AudioProcessor(verbose=verbose, **config) ### setting up the parameters ### - def _build_mel_basis( - self, - ) -> np.ndarray: - """Build melspectrogram basis. - - Returns: - np.ndarray: melspectrogram basis. - """ - if self.mel_fmax is not None: - assert self.mel_fmax <= self.sample_rate // 2 - return librosa.filters.mel( - self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax - ) - - def _stft_parameters( - self, - ) -> Tuple[int, int]: - """Compute the real STFT parameters from the time values. - - Returns: - Tuple[int, int]: hop length and window length for STFT. - """ - factor = self.frame_length_ms / self.frame_shift_ms - assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" - hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) - win_length = int(hop_length * factor) - return hop_length, win_length - - ### normalization ### - def normalize(self, S: np.ndarray) -> np.ndarray: - """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` - - Args: - S (np.ndarray): Spectrogram to normalize. - - Raises: - RuntimeError: Mean and variance is computed from incompatible parameters. - - Returns: - np.ndarray: Normalized spectrogram. - """ - # pylint: disable=no-else-return - S = S.copy() - if self.signal_norm: - # mean-var scaling - if hasattr(self, "mel_scaler"): - if S.shape[0] == self.num_mels: - return self.mel_scaler.transform(S.T).T - elif S.shape[0] == self.fft_size / 2: - return self.linear_scaler.transform(S.T).T - else: - raise RuntimeError(" [!] 
Mean-Var stats does not match the given feature dimensions.") - # range normalization - S -= self.ref_level_db # discard certain range of DB assuming it is air noise - S_norm = (S - self.min_level_db) / (-self.min_level_db) - if self.symmetric_norm: - S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm - if self.clip_norm: - S_norm = np.clip( - S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type - ) - return S_norm - else: - S_norm = self.max_norm * S_norm - if self.clip_norm: - S_norm = np.clip(S_norm, 0, self.max_norm) - return S_norm - else: - return S - - def denormalize(self, S: np.ndarray) -> np.ndarray: - """Denormalize spectrogram values. - - Args: - S (np.ndarray): Spectrogram to denormalize. - - Raises: - RuntimeError: Mean and variance are incompatible. - - Returns: - np.ndarray: Denormalized spectrogram. - """ - # pylint: disable=no-else-return - S_denorm = S.copy() - if self.signal_norm: - # mean-var scaling - if hasattr(self, "mel_scaler"): - if S_denorm.shape[0] == self.num_mels: - return self.mel_scaler.inverse_transform(S_denorm.T).T - elif S_denorm.shape[0] == self.fft_size / 2: - return self.linear_scaler.inverse_transform(S_denorm.T).T - else: - raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") - if self.symmetric_norm: - if self.clip_norm: - S_denorm = np.clip( - S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type - ) - S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db - return S_denorm + self.ref_level_db - else: - if self.clip_norm: - S_denorm = np.clip(S_denorm, 0, self.max_norm) - S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db - return S_denorm + self.ref_level_db - else: - return S_denorm - - ### Mean-STD scaling ### - def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: - """Loading mean and variance statistics from a `npy` file. - - Args: - stats_path (str): Path to the `npy` file containing - - Returns: - Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to - compute them. - """ - stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg - mel_mean = stats["mel_mean"] - mel_std = stats["mel_std"] - linear_mean = stats["linear_mean"] - linear_std = stats["linear_std"] - stats_config = stats["audio_config"] - # check all audio parameters used for computing stats - skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] - for key in stats_config.keys(): - if key in skip_parameters: - continue - if key not in ["sample_rate", "trim_db"]: - assert ( - stats_config[key] == self.__dict__[key] - ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" - return mel_mean, mel_std, linear_mean, linear_std, stats_config - - # pylint: disable=attribute-defined-outside-init - def setup_scaler( - self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray - ) -> None: - """Initialize scaler objects used in mean-std normalization. - - Args: - mel_mean (np.ndarray): Mean for melspectrograms. - mel_std (np.ndarray): STD for melspectrograms. - linear_mean (np.ndarray): Mean for full scale spectrograms. - linear_std (np.ndarray): STD for full scale spectrograms. 
- """ - self.mel_scaler = StandardScaler() - self.mel_scaler.set_stats(mel_mean, mel_std) - self.linear_scaler = StandardScaler() - self.linear_scaler.set_stats(linear_mean, linear_std) ### DB and AMP conversion ### # pylint: disable=no-self-use @@ -737,8 +417,7 @@ def compute_f0(self, x: np.ndarray) -> np.ndarray: Examples: >>> WAV_FILE = filename = librosa.util.example_audio_file() >>> from TTS.config import BaseAudioConfig - >>> from TTS.utils.audio import AudioProcessor - >>> conf = BaseAudioConfig(pitch_fmax=8000) + >>> from TTS.utils.audio.processor import AudioProcessor >>> conf = BaseAudioConfig(pitch_fmax=8000) >>> ap = AudioProcessor(**conf) >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] >>> pitch = ap.compute_f0(wav) @@ -913,15 +592,3 @@ def quantize(x: np.ndarray, bits: int) -> np.ndarray: def dequantize(x, bits): """Dequantize a waveform from the given number of bits.""" return 2 * x / (2**bits - 1) - 1 - - -def _log(x, base): - if base == 10: - return np.log10(x) - return np.log(x) - - -def _exp(x, base): - if base == 10: - return np.power(10, x) - return np.exp(x) diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py new file mode 100644 index 0000000000..21e7b2343a --- /dev/null +++ b/TTS/utils/audio/torch_transforms.py @@ -0,0 +1,165 @@ +import librosa +import torch +from torch import nn + + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing funtions using Torch for faster batch processing. + + TODO: Merge this with audio.py + + Args: + + n_fft (int): + FFT window size for STFT. + + hop_length (int): + number of frames between STFT columns. + + win_length (int, optional): + STFT window length. + + pad_wav (bool, optional): + If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. + + window (str, optional): + The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" + + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + n_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + use_mel (bool, optional): + If True compute the melspectrograms otherwise. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. + + spec_gain (float, optional): + gain applied when converting amplitude to DB. Defaults to 1.0. + + power (float, optional): + Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. + + use_htk (bool, optional): + Use HTK formula in mel filter instead of Slaney. + + mel_norm (None, 'slaney', or number, optional): + If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). + + If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. + See `librosa.util.normalize` for a full description of supported norm values + (including `+-np.inf`). + + Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". 
+ """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + power=None, + use_htk=False, + mel_norm="slaney", + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.power = power + self.use_htk = use_htk + self.mel_norm = mel_norm + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. + + Shapes: + x: [B x T] or [:math:`[B, 1, T]`] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=False, + onesided=True, + return_complex=False, + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + + if self.power is not None: + S = S**self.power + + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + self.sample_rate, + self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 304df5ed21..7ed0d56191 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -84,116 +84,4 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli model.cuda() if eval: model.eval() - return model, state - - -def save_fsspec(state: Any, path: str, **kwargs): - """Like torch.save but can save to other locations (e.g. s3:// , gs://). - - Args: - state: State object to save - path: Any path or url supported by fsspec. - **kwargs: Keyword arguments forwarded to torch.save. 
- """ - with fsspec.open(path, "wb") as f: - torch.save(state, f, **kwargs) - - -def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): - if hasattr(model, "module"): - model_state = model.module.state_dict() - else: - model_state = model.state_dict() - if isinstance(optimizer, list): - optimizer_state = [optim.state_dict() for optim in optimizer] - else: - optimizer_state = optimizer.state_dict() if optimizer is not None else None - - if isinstance(scaler, list): - scaler_state = [s.state_dict() for s in scaler] - else: - scaler_state = scaler.state_dict() if scaler is not None else None - - if isinstance(config, Coqpit): - config = config.to_dict() - - state = { - "config": config, - "model": model_state, - "optimizer": optimizer_state, - "scaler": scaler_state, - "step": current_step, - "epoch": epoch, - "date": datetime.date.today().strftime("%B %d, %Y"), - } - state.update(kwargs) - save_fsspec(state, output_path) - - -def save_checkpoint( - config, - model, - optimizer, - scaler, - current_step, - epoch, - output_folder, - **kwargs, -): - file_name = "checkpoint_{}.pth".format(current_step) - checkpoint_path = os.path.join(output_folder, file_name) - print("\n > CHECKPOINT : {}".format(checkpoint_path)) - save_model( - config, - model, - optimizer, - scaler, - current_step, - epoch, - checkpoint_path, - **kwargs, - ) - - -def save_best_model( - current_loss, - best_loss, - config, - model, - optimizer, - scaler, - current_step, - epoch, - out_path, - keep_all_best=False, - keep_after=10000, - **kwargs, -): - if current_loss < best_loss: - best_model_name = f"best_model_{current_step}.pth" - checkpoint_path = os.path.join(out_path, best_model_name) - print(" > BEST MODEL : {}".format(checkpoint_path)) - save_model( - config, - model, - optimizer, - scaler, - current_step, - epoch, - checkpoint_path, - model_loss=current_loss, - **kwargs, - ) - fs = fsspec.get_mapper(out_path).fs - # only delete previous if current is saved successfully - if not keep_all_best or (current_step < keep_after): - model_names = fs.glob(os.path.join(out_path, "best_model*.pth")) - for model_name in model_names: - if os.path.basename(model_name) != best_model_name: - fs.rm(model_name) - # create a shortcut which always points to the currently best model - shortcut_name = "best_model.pth" - shortcut_path = os.path.join(out_path, shortcut_name) - fs.copy(checkpoint_path, shortcut_path) - best_loss = current_loss - return best_loss + return model, state \ No newline at end of file diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 1f33b53e77..bce7528ce6 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -11,7 +11,7 @@ # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index 848e292b83..befc43cca6 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import functional as F -from TTS.utils.audio import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss 
################################# diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index ed5b26dd93..51cdefc242 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -209,9 +209,9 @@ def train_log( self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument ) -> Tuple[Dict, np.ndarray]: """Call `_log()` for training.""" - figures, audios = self._log("eval", self.ap, batch, outputs) - logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, self.ap.sample_rate) + figures, audios = self._log("train", self.ap, batch, outputs) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: diff --git a/TTS/vocoder/models/univnet_discriminator.py b/TTS/vocoder/models/univnet_discriminator.py index d6b0e5d52c..34e2d1c276 100644 --- a/TTS/vocoder/models/univnet_discriminator.py +++ b/TTS/vocoder/models/univnet_discriminator.py @@ -3,7 +3,7 @@ from torch import nn from torch.nn.utils import spectral_norm, weight_norm -from TTS.utils.audio import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator LRELU_SLOPE = 0.1 diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 63a0af4445..ce2b56fbdd 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -5,7 +5,8 @@ from matplotlib import pyplot as plt from TTS.tts.utils.visual import plot_spectrogram -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import amp_to_db, build_mel_basis, wav_to_mel +from TTS.utils.audio.processor import AudioProcessor def interpolate_vocoder_input(scale_factor, spec): @@ -29,13 +30,20 @@ def interpolate_vocoder_input(scale_factor, spec): return spec -def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict: +def plot_results( + y_hat: torch.tensor, + y: torch.tensor, + ap: AudioProcessor = None, + audio_config: "Coqpit" = None, + name_prefix: str = None, +) -> Dict: """Plot the predicted and the real waveform and their spectrograms. Args: y_hat (torch.tensor): Predicted waveform. y (torch.tensor): Real waveform. - ap (AudioProcessor): Audio processor used to process the waveform. + ap (AudioProcessor): Audio processor used to process the waveform. Defaults to None. + audio_config (Coqpit): Audio configuration. Only used when ```ap``` is None. Defaults to None. name_prefix (str, optional): Name prefix used to name the figures. Defaults to None. Returns: @@ -48,8 +56,17 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_ y_hat = y_hat[0].squeeze().detach().cpu().numpy() y = y[0].squeeze().detach().cpu().numpy() - spec_fake = ap.melspectrogram(y_hat).T - spec_real = ap.melspectrogram(y).T + if ap is not None: + spec_fake = ap.melspectrogram(y_hat).T + spec_real = ap.melspectrogram(y).T + elif audio_config is not None: + mel_basis = build_mel_basis(**audio_config) + spec_fake = wav_to_mel(y=y_hat, mel_basis=mel_basis, **audio_config).T + spec_real = wav_to_mel(y=y, mel_basis=mel_basis, **audio_config).T + spec_fake = amp_to_db(x=spec_fake, gain=1.0, base=10.0) + spec_real = amp_to_db(x=spec_real, gain=1.0, base=10.0) + else: + raise ValueError(" [!] 
Either `ap` or `audio_config` must be provided.") spec_diff = np.abs(spec_fake - spec_real) # plot figure and save it diff --git a/recipes/ljspeech/align_tts/train_aligntts.py b/recipes/ljspeech/align_tts/train_aligntts.py index 591b15091f..da673e5443 100644 --- a/recipes/ljspeech/align_tts/train_aligntts.py +++ b/recipes/ljspeech/align_tts/train_aligntts.py @@ -7,7 +7,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.align_tts import AlignTTS from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index a84658f35f..694552fb4a 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -7,7 +7,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.utils.manage import ModelManager output_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py new file mode 100644 index 0000000000..754c1a0813 --- /dev/null +++ b/recipes/ljspeech/fast_pitch_e2e/train_fast_pitch_e2e.py @@ -0,0 +1,88 @@ +import os + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs, ForwardTTSE2eAudio +from TTS.tts.utils.text.tokenizer import TTSTokenizer + +output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs +dataset_config = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + path=os.path.join(output_path, "../LJSpeech-1.1/"), +) + +audio_config = ForwardTTSE2eAudio( + sample_rate=22050, + hop_length=256, + win_length=1024, + fft_size=1024, + mel_fmin=0.0, + mel_fmax=8000, + pitch_fmax=640.0, + num_mels=80, +) + +# vocoder_config = HifiganConfig() +model_args = ForwardTTSE2eArgs() + +config = FastPitchE2eConfig( + run_name="fast_pitch_e2e_ljspeech", + run_description="Train like in FS2 paper.", + model_args=model_args, + audio=audio_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=8, + num_eval_loader_workers=4, + compute_input_seq_cache=True, + compute_f0=True, + f0_cache_path=os.path.join(output_path, "f0_cache"), + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=4, + print_step=50, + print_eval=False, + mixed_precision=False, + sort_by_audio_len=True, + output_path=output_path, + datasets=[dataset_config], + start_by_longest=False, + binary_align_loss_alpha=0.0, +) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. 
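+# For example (illustrative only), the tokenizer created below maps raw text to token IDs and back:
+#   token_ids = tokenizer.text_to_ids("hello world")
+#   text = tokenizer.ids_to_text(token_ids)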
+# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. +train_samples, eval_samples = load_tts_samples( + dataset_config, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, +) + +# init the model +model = ForwardTTSE2e(config=config, tokenizer=tokenizer, speaker_manager=None) + +# init the trainer and 🚀 +trainer = Trainer( + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples +) +trainer.fit() diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index 0245dd938b..2a525c58d4 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -7,7 +7,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.utils.manage import ModelManager output_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index a0b4ac48b4..85ca450cf2 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -12,7 +12,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor # we use the same path as this script as our training folder. 
output_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/recipes/ljspeech/hifigan/train_hifigan.py b/recipes/ljspeech/hifigan/train_hifigan.py index b4cbae63ed..c96f721bb6 100644 --- a/recipes/ljspeech/hifigan/train_hifigan.py +++ b/recipes/ljspeech/hifigan/train_hifigan.py @@ -2,7 +2,7 @@ from trainer import Trainer, TrainerArgs -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.vocoder.configs import HifiganConfig from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.models.gan import GAN diff --git a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py index 225f5a302f..e8f3d06636 100644 --- a/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py +++ b/recipes/ljspeech/multiband_melgan/train_multiband_melgan.py @@ -2,7 +2,7 @@ from trainer import Trainer, TrainerArgs -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.vocoder.configs import MultibandMelganConfig from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.models.gan import GAN diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 1ab3db1c2e..9d51c36a6f 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -7,7 +7,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig( diff --git a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py index a9f253ea86..ddd58bd68a 100644 --- a/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py +++ b/recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py @@ -8,7 +8,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor # from TTS.tts.datasets.tokenizer import Tokenizer diff --git a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py index 99089db83e..fd30943a2b 100644 --- a/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py +++ b/recipes/ljspeech/tacotron2-DDC/train_tacotron_ddc.py @@ -8,7 +8,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor # from TTS.tts.datasets.tokenizer import Tokenizer diff --git a/recipes/ljspeech/univnet/train.py b/recipes/ljspeech/univnet/train.py index 81d2b889b9..471333b35f 100644 --- a/recipes/ljspeech/univnet/train.py +++ b/recipes/ljspeech/univnet/train.py @@ -2,7 +2,7 @@ from trainer import Trainer, TrainerArgs -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.vocoder.configs import UnivnetConfig from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.models.gan import GAN diff --git a/recipes/ljspeech/vits_tts/train_vits.py 
b/recipes/ljspeech/vits_tts/train_vits.py index c070b3f1ce..203be22c2f 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -8,7 +8,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.models.vits import Vits from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig( @@ -37,7 +37,7 @@ batch_size=32, eval_batch_size=16, batch_group_size=5, - num_loader_workers=0, + num_loader_workers=8, num_eval_loader_workers=4, run_eval=True, test_delay_epochs=-1, diff --git a/recipes/ljspeech/wavegrad/train_wavegrad.py b/recipes/ljspeech/wavegrad/train_wavegrad.py index 1abdf45d87..be9e0a09c8 100644 --- a/recipes/ljspeech/wavegrad/train_wavegrad.py +++ b/recipes/ljspeech/wavegrad/train_wavegrad.py @@ -2,7 +2,7 @@ from trainer import Trainer, TrainerArgs -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.vocoder.configs import WavegradConfig from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.models.wavegrad import Wavegrad diff --git a/recipes/ljspeech/wavernn/train_wavernn.py b/recipes/ljspeech/wavernn/train_wavernn.py index 640f509218..75be20e21c 100644 --- a/recipes/ljspeech/wavernn/train_wavernn.py +++ b/recipes/ljspeech/wavernn/train_wavernn.py @@ -2,7 +2,7 @@ from trainer import Trainer, TrainerArgs -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor from TTS.vocoder.configs import WavernnConfig from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.models.wavernn import Wavernn diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index 0e650ade8e..c4b6216632 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -11,7 +11,7 @@ from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index c39932daaa..a1d838f56e 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -8,7 +8,7 @@ from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/")) diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index a3249de1cf..3bc839504c 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -8,7 +8,7 @@ from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import 
AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/")) diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index 23c02efc79..f82fca6306 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -9,7 +9,7 @@ from TTS.tts.models.glow_tts import GlowTTS from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor # set experiment paths output_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/recipes/vctk/speedy_speech/train_speedy_speech.py b/recipes/vctk/speedy_speech/train_speedy_speech.py index bcd0105af8..b24b2f3b4c 100644 --- a/recipes/vctk/speedy_speech/train_speedy_speech.py +++ b/recipes/vctk/speedy_speech/train_speedy_speech.py @@ -8,7 +8,7 @@ from TTS.tts.models.forward_tts import ForwardTTS from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/")) diff --git a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py index 36e28ed769..efdb150e78 100644 --- a/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py +++ b/recipes/vctk/tacotron-DDC/train_tacotron-DDC.py @@ -9,7 +9,7 @@ from TTS.tts.models.tacotron import Tacotron from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/")) diff --git a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py index d04d91c066..ea17508506 100644 --- a/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py +++ b/recipes/vctk/tacotron2-DDC/train_tacotron2-ddc.py @@ -9,7 +9,7 @@ from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/")) diff --git a/recipes/vctk/tacotron2/train_tacotron2.py b/recipes/vctk/tacotron2/train_tacotron2.py index 5a0e157a93..76bc25d4e8 100644 --- a/recipes/vctk/tacotron2/train_tacotron2.py +++ b/recipes/vctk/tacotron2/train_tacotron2.py @@ -9,7 +9,7 @@ from TTS.tts.models.tacotron2 import Tacotron2 from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/")) diff --git a/recipes/vctk/vits/train_vits.py 
b/recipes/vctk/vits/train_vits.py index 88fd7de9a1..9aeb2de770 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -9,7 +9,7 @@ from TTS.tts.models.vits import Vits, VitsArgs from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor output_path = os.path.dirname(os.path.abspath(__file__)) dataset_config = BaseDatasetConfig( diff --git a/tests/tts_tests/test_fast_pitch_e2e.py b/tests/tts_tests/test_fast_pitch_e2e.py new file mode 100644 index 0000000000..ca34dec92f --- /dev/null +++ b/tests/tts_tests/test_fast_pitch_e2e.py @@ -0,0 +1,365 @@ +import copy +import os +import unittest + +import torch +from trainer.logging.tensorboard_logger import TensorboardLogger + +from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path +from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig +from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs + +LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") +SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +MAX_INPUT_LEN = 57 +MAX_SPEC_LEN = 33 + + +# pylint: disable=no-self-use +class TestFastPitchE2E(unittest.TestCase): + def _create_inputs(self, config, batch_size=2): + + input_dummy = torch.randint(0, 24, (batch_size, MAX_INPUT_LEN)).long().to(device) + input_lengths = torch.randint(10, MAX_INPUT_LEN, (batch_size,)).long().to(device) + input_lengths[-1] = MAX_INPUT_LEN + spec = torch.rand(batch_size, MAX_SPEC_LEN, config.audio["num_mels"]).to(device) + spec_lengths = torch.randint(20, MAX_SPEC_LEN, (batch_size,)).long().to(device) + spec_lengths[-1] = MAX_SPEC_LEN + waveform = torch.rand(batch_size, 1, spec.size(1) * config.audio["hop_length"]).to(device) + pitch = torch.rand(batch_size, 1, spec.size(1)).to(device) + return input_dummy, input_lengths, spec, spec_lengths, waveform, pitch + + def _check_forward_outputs(self, config, output_dict, batch_size=2): + self.assertEqual( + output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] + ) + self.assertEqual(output_dict["alignments"].shape, (batch_size, MAX_SPEC_LEN, MAX_INPUT_LEN)) + self.assertEqual(output_dict["alignments"].max(), 1) + self.assertEqual(output_dict["alignments"].min(), 0) + self.assertEqual( + output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] + ) + + def _check_inference_outputs(self, outputs, input_dummy, batch_size=1): + feat_dim = 256 # hard-coded based on model architecture + feat_len = outputs["o_en_ex"].shape[2] + self.assertEqual(outputs["o_en_ex"].shape, (batch_size, feat_dim, feat_len)) + self.assertEqual(outputs["model_outputs"].shape[:2], (batch_size, 1)) # we don't know the channel dimension + self.assertEqual(outputs["alignments"].shape, (batch_size, input_dummy.shape[1], feat_len)) + + @staticmethod + def _check_parameter_changes(model, model_ref): + count = 0 + for item1, item2 in zip(model.named_parameters(), model_ref.named_parameters()): + name = item1[0] + param = item1[1] + param_ref = item2[1] + assert (param != param_ref).any(), "param {} 
with shape {} not updated!! \n{}\n{}".format( + name, param.shape, param, param_ref + ) + count = count + 1 + + def _create_batch(self, config, batch_size): + input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs(config, batch_size) + batch = {} + batch["text_input"] = input_dummy + batch["text_lengths"] = input_lengths + batch["mel_lengths"] = spec_lengths + batch["mel_input"] = spec + batch["waveform"] = waveform # B x C X T + batch["d_vectors"] = None + batch["speaker_ids"] = None + batch["language_ids"] = None + batch["pitch"] = pitch + return batch + + def test_init_multispeaker(self): + + num_speakers = 10 + model_args = ForwardTTSE2eArgs() + model_args.num_speakers = num_speakers + model_args.use_speaker_embedding = True + model = ForwardTTSE2e(model_args) + assertHasAttr(self, model.encoder_model, "emb_g") + + model_args = ForwardTTSE2eArgs() + model_args.num_speakers = 0 + model_args.use_speaker_embedding = True + model = ForwardTTSE2e(model_args) + assertHasNotAttr(self, model.encoder_model, "emb_g") + + model_args = ForwardTTSE2eArgs() + model_args.num_speakers = 10 + model_args.use_speaker_embedding = False + model = ForwardTTSE2e(model_args) + assertHasNotAttr(self, model.encoder_model, "emb_g") + + model_args = ForwardTTSE2eArgs(d_vector_dim=101, use_d_vector_file=True) + model = ForwardTTSE2e(model_args) + self.assertEqual(model.encoder_model.embedded_speaker_dim, 101) + + def test_init_multilingual(self): + """TODO""" + + def test_get_aux_input(self): + aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None} + model_args = ForwardTTSE2eArgs() + model = ForwardTTSE2e(model_args) + aux_out = model.get_aux_input(aux_input) + + speaker_id = torch.randint(10, (1,)) + language_id = torch.randint(10, (1,)) + d_vector = torch.rand(1, 128) + aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id} + aux_out = model.get_aux_input(aux_input) + self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape) + self.assertEqual(aux_out["language_ids"].shape, language_id.shape) + self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape) + + def test_forward(self): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs(config) + model = ForwardTTSE2e(config).to(device) + output_dict = model.forward( + x=input_dummy, x_lengths=input_lengths, spec=spec, spec_lengths=spec_lengths, waveform=waveform, pitch=pitch + ) + self._check_forward_outputs(config, output_dict) + + def test_multispeaker_forward(self): + batch_size = 2 + num_speakers = 10 + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=num_speakers, use_speaker_embedding=True) + config = FastPitchE2eConfig(model_args=model_args) + config.model_args.spec_segment_size = 10 + + input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs( + config, batch_size=batch_size + ) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + + model = ForwardTTSE2e(config).to(device) + output_dict = model.forward( + x=input_dummy, + x_lengths=input_lengths, + spec=spec, + spec_lengths=spec_lengths, + waveform=waveform, + pitch=pitch, + aux_input={"speaker_ids": speaker_ids}, + ) + self._check_forward_outputs(config, output_dict) + + def test_d_vector_forward(self): + batch_size = 2 + 
model_args = ForwardTTSE2eArgs(spec_segment_size=10, use_d_vector_file=True, d_vector_dim=256) + config = FastPitchE2eConfig(model_args=model_args) + config.model_args.spec_segment_size = 10 + model = ForwardTTSE2e(config).to(device) + model.train() + input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs( + config, batch_size=batch_size + ) + d_vectors = torch.randn(batch_size, 256).to(device) + output_dict = model.forward( + x=input_dummy, + x_lengths=input_lengths, + spec=spec, + spec_lengths=spec_lengths, + waveform=waveform, + pitch=pitch, + aux_input={"d_vectors": d_vectors}, + ) + self._check_forward_outputs(config, output_dict) + + # def test_multilingual_forward(self): + # """TODO""" + + def test_inference(self): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) + model.eval() + + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + outputs = model.inference(input_dummy.to(device)) + self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) + + # TODO: implement batched inference + # batch_size = 2 + # input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + # outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths}) + # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) + + def test_multispeaker_inference(self): + num_speakers = 10 + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=num_speakers, use_speaker_embedding=True) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) + + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids}) + self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) + + # batch_size = 2 + # input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + # speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + # outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) + # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) + + # def test_multilingual_inference(self): + # """TODO""" + + def test_d_vector_inference(self): + model_args = ForwardTTSE2eArgs( + spec_segment_size=10, + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) + model.eval() + # batch size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=1) + d_vectors = torch.randn(1, 256).to(device) + outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors}) + self._check_inference_outputs(outputs, input_dummy) + # batch size = 2 + # input_dummy, input_lengths, *_ = self._create_inputs(config) + # d_vectors = torch.randn(2, 256).to(device) + # outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors}) + # self._check_inference_outputs(outputs, input_dummy, batch_size=2) + + def test_train_step(self): + # setup the model + with torch.autograd.set_detect_anomaly(True): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config
= FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) + model.train() + # model to train + optimizers = model.get_optimizer() + criterions = model.get_criterion() + criterions = [criterions[0].to(device), criterions[1].to(device)] + # reference model to compare model weights + model_ref = ForwardTTSE2e(config).to(device) + # # pass the state to ref model + model_ref.load_state_dict(copy.deepcopy(model.state_dict())) + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count = count + 1 + for _ in range(5): + batch = self._create_batch(config, 2) + for idx in [0, 1]: + outputs, loss_dict = model.train_step(batch, criterions, idx) + self.assertFalse(not outputs) + self.assertFalse(not loss_dict) + loss_dict["loss"].backward() + optimizers[idx].step() + optimizers[idx].zero_grad() + + # check parameter changes + self._check_parameter_changes(model, model_ref) + + def test_train_eval_log(self): + batch_size = 2 + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + model.train() + model.on_init_start(trainer=None) # create mel_basis + batch = self._create_batch(config, batch_size) + logger = TensorboardLogger( + log_dir=os.path.join(get_tests_output_path(), "dummy_fast_pitch_e2e_logs"), + model_name="fast_pitch_e2e_test_train_log", + ) + criterion = model.get_criterion() + criterion = [criterion[0].to(device), criterion[1].to(device)] + outputs = [None] * 2 + outputs[0], _ = model.train_step(batch, criterion, 0) + outputs[1], _ = model.train_step(batch, criterion, 1) + model.train_log(batch=batch, outputs=outputs, logger=logger, assets=None, steps=1) + model.eval_log(batch, outputs, logger, None, 1) + logger.finish() + + def test_test_run(self): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + model.eval() + model.on_init_start(trainer=None) # create mel_basis + test_figures, test_audios = model.test_run(None) + self.assertTrue(test_figures is not None) + self.assertTrue(test_audios is not None) + + def test_load_checkpoint(self): + chkp_path = os.path.join(get_tests_output_path(), "dummy_fast_pitch_e2e_tts_checkpoint.pth") + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + chkp = {} + chkp["model"] = model.state_dict() + torch.save(chkp, chkp_path) + model.load_checkpoint(config, chkp_path) + self.assertTrue(model.training) + model.load_checkpoint(config, chkp_path, eval=True) + self.assertFalse(model.training) + + def test_get_criterion(self): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + criterion = model.get_criterion() + self.assertTrue(criterion is not None) + + def test_init_from_config(self): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=2) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, 
verbose=False).to(device) + self.assertTrue(not hasattr(model, "emb_g")) + + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=2, use_speaker_embedding=True) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + self.assertEqual(model.num_speakers, 2) + self.assertTrue(hasattr(model, "emb_g")) + + model_args = ForwardTTSE2eArgs( + spec_segment_size=10, + num_speakers=2, + use_speaker_embedding=True, + speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + ) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + self.assertEqual(model.num_speakers, 10) + self.assertTrue(hasattr(model, "emb_g")) + + model_args = ForwardTTSE2eArgs( + spec_segment_size=10, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + ) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) + self.assertTrue(model.num_speakers == 10) + self.assertTrue(not hasattr(model, "emb_g")) + self.assertTrue(model.embedded_speaker_dim == config.model_args.d_vector_dim)
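
Note: for downstream code, the recipe and vocoder hunks above reduce to an import-path migration; a minimal before/after sketch (only import paths that appear in the hunks above are used):

    # before this patch
    # from TTS.utils.audio import AudioProcessor
    # from TTS.utils.audio import TorchSTFT

    # after this patch
    from TTS.utils.audio.processor import AudioProcessor
    from TTS.utils.audio.torch_transforms import TorchSTFT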