ForwardTTSE2E implementations and related API changes #1510

Closed · wants to merge 32 commits
Changes from all commits (32 commits):
f237e4c  Merge pull request #1574 from coqui-ai/update_badge (erogol, May 13, 2022)
fccda5a  Implement ForwardTTSE2E (erogol, Apr 4, 2022)
8573148  Implement FastPitchE2EConfig (erogol, Apr 4, 2022)
2a61b8f  Implement ForwardTTSE2E tests (erogol, Apr 4, 2022)
aea8cb7  Implement FastPitchE2E LJSpeech recipe (erogol, Apr 4, 2022)
b16613c  Implement ForwardTTSE2E Loss (erogol, Apr 4, 2022)
c125024  Implement BaseTTSE2E (erogol, Apr 4, 2022)
28a53c7  Refactor multi-speaker init in ForwardTTS (erogol, Apr 4, 2022)
775a6ab  Add cond layer in decoder (erogol, Apr 4, 2022)
760f045  Rename vars in VITS (erogol, Apr 4, 2022)
0738cb0  Fix Vocoder logging (erogol, Apr 4, 2022)
9f8d86b  Remove redundancy (erogol, Apr 4, 2022)
5f9d559  Update import statements (erogol, Apr 19, 2022)
4556c61  Update fastpitche2e recipe (erogol, Apr 19, 2022)
231c69b  Remove AP from FastPitchE2e (erogol, Apr 19, 2022)
e7c5db0  Add missing kernel size attr to transformer layer (erogol, Apr 19, 2022)
cc57c20  Make plot results more general (erogol, Apr 19, 2022)
c3fb49b  Refactor ForwardTTS to skip decoder (erogol, Apr 19, 2022)
6a53b77  Add numpy and torch transforms (erogol, Apr 19, 2022)
dbe5eb9  Make AP optional in BaseTTS (erogol, Apr 19, 2022)
4171f4e  Update ForwardTTSE2eLoss (erogol, Apr 19, 2022)
0b585b4  Refactor TTSDataset to use numpy transforms (erogol, Apr 19, 2022)
edd59c8  Update ForwardTTSe2e tests (erogol, Apr 19, 2022)
9291d13  Make style (erogol, Apr 19, 2022)
96779e7  Return duration by ForwardTTS inference (erogol, Apr 19, 2022)
ce4f962  Remove renamed trainer functions (erogol, Apr 22, 2022)
b3fb0e1  Implement get_state_dict (erogol, Apr 22, 2022)
a05c82f  Fix audio_config handling (erogol, Apr 22, 2022)
c437db1  Fix dirt (erogol, May 17, 2022)
8e915b7  Make hifigan discriminator configurable (erogol, May 17, 2022)
2d29e82  Fix up (erogol, May 17, 2022)
8adcd1d  Rename `g` as `spk_emb` (erogol, May 17, 2022)
2 changes: 1 addition & 1 deletion TTS/encoder/models/resnet.py
@@ -1,7 +1,7 @@
import torch
from torch import nn

# from TTS.utils.audio import TorchSTFT
# from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.encoder.models.base_encoder import BaseEncoder


2 changes: 1 addition & 1 deletion TTS/encoder/utils/generic_utils.py
@@ -9,7 +9,7 @@

from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
from trainer.io import save_fsspec


class AugmentWAV(object):
178 changes: 178 additions & 0 deletions TTS/tts/configs/fast_pitch_e2e_config.py
@@ -0,0 +1,178 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts_e2e import ForwardTTSE2eArgs


@dataclass
class FastPitchE2eConfig(BaseTTSConfig):
"""Configure `ForwardTTSE2e` as FastPitchE2e model.

Example:

>>> from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig
>>> config = FastPitchE2eConfig()

Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch_e2e`.

base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts_e2e`.

model_args (Coqpit):
Model class arguments. Check `ForwardTTSE2eArgs` for more details. Defaults to `ForwardTTSE2eArgs()`.

data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
for the rest of training. Defaults to 10.

speakers_file (str):
Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to
speaker names. Defaults to `None`.

use_speaker_embedding (bool):
Enable / disable using speaker embeddings for multi-speaker models. If set to True, the model operates
in multi-speaker mode. Defaults to False.

use_d_vector_file (bool):
Enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.

d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.

d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.

optimizer (str):
Name of the model optimizer. Defaults to `AdamW`.

optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}`.

lr_scheduler_gen (str):
Name of the generator learning rate scheduler. Defaults to `ExponentialLR`.

lr_scheduler_gen_params (dict):
Arguments of the generator learning rate scheduler. Defaults to `{"gamma": 0.999875, "last_epoch": -1}`.

lr_scheduler_disc (str):
Name of the discriminator learning rate scheduler. Defaults to `ExponentialLR`.

lr_scheduler_disc_params (dict):
Arguments of the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999875, "last_epoch": -1}`.

lr_gen (float):
Initial learning rate of the generator. Defaults to `0.0002`.

lr_disc (float):
Initial learning rate of the discriminator. Defaults to `0.0002`.

grad_clip (List[float]):
Gradient norm clipping values for the generator and discriminator optimizers. Defaults to `[1000, 1000]`.

spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.

duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.

use_ssim_loss (bool):
Enable/disable the use of the SSIM (Structural Similarity) loss. Defaults to False.

ssim_loss_alpha (float):
Weight for the SSIM loss. If set to 0, the SSIM loss is disabled. Defaults to 0.0.

dur_loss_alpha (float):
Weight for the duration predictor's loss. If set to 0, the duration loss is disabled. Defaults to 1.0.

spec_loss_alpha (float):
Weight for the spectrogram loss. If set to 0, the spectrogram loss is disabled. Defaults to 1.0.

pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set to 0, the pitch predictor is disabled. Defaults to 1.0.

binary_align_loss_alpha (float):
Weight for the binary alignment loss. If set to 0, the binary alignment loss is disabled. Defaults to 0.1.

binary_loss_warmup_epochs (int):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.

min_seq_len (int):
Minimum input sequence length to be used at training.

max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""

model: str = "fast_pitch_e2e"
base_model: str = "forward_tts_e2e"

# model specific params
# model_args: ForwardTTSE2eArgs = ForwardTTSE2eArgs(vocoder_config=HifiganConfig())
model_args: ForwardTTSE2eArgs = ForwardTTSE2eArgs()

# multi-speaker settings
# num_speakers: int = 0
# speakers_file: str = None
# use_speaker_embedding: bool = False
# use_d_vector_file: bool = False
# d_vector_file: str = False
# d_vector_dim: int = 0
spec_segment_size: int = 30

# optimizer
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR"
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
scheduler_after_epoch: bool = True
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

# encoder loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = False
ssim_loss_alpha: float = 0.0
spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
pitch_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 0.1
binary_loss_warmup_epochs: int = 150

# vocoder loss params
disc_loss_alpha: float = 1.0
gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 10.0
multi_scale_stft_loss_alpha: float = 2.5
multi_scale_stft_loss_params: dict = field(
default_factory=lambda: {
"n_ffts": [1024, 2048, 512],
"hop_lengths": [120, 240, 50],
"win_lengths": [600, 1200, 240],
}
)

# data loader params
return_wav: bool = True

# overrides
r: int = 1

# dataset configs
compute_f0: bool = True
f0_cache_path: str = None

# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
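
For orientation, here is a minimal sketch of how the new config might be instantiated in a recipe. It is not taken from this PR: the dataset settings, run name, and batch size are illustrative assumptions; only the `FastPitchE2eConfig` fields themselves come from the file above.

```python
# Hypothetical usage sketch; dataset paths and training settings are assumptions.
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig

dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path="/data/LJSpeech-1.1/"  # assumed paths
)

config = FastPitchE2eConfig(
    run_name="fast_pitch_e2e_ljspeech",  # assumed run name
    batch_size=32,                       # assumed batch size
    compute_f0=True,                     # pitch targets are needed by the pitch predictor
    f0_cache_path="/tmp/f0_cache",       # assumed cache dir for precomputed pitch contours
    datasets=[dataset_config],
)
print(config.model, config.base_model)  # fast_pitch_e2e forward_tts_e2e
```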
7 changes: 2 additions & 5 deletions TTS/tts/configs/vits_config.py
@@ -116,12 +116,8 @@ class VitsConfig(BaseTTSConfig):
dur_loss_alpha: float = 1.0
speaker_encoder_loss_alpha: float = 1.0

# data loader params
return_wav: bool = True
compute_linear_spec: bool = True

# overrides
r: int = 1 # DO NOT CHANGE
# r: int = 1 # DO NOT CHANGE
add_blank: bool = True

# testing
@@ -137,6 +133,7 @@ class VitsConfig(BaseTTSConfig):

# multi-speaker settings
# use speaker embedding layer
# TODO: keep this only in VitsArgs
num_speakers: int = 0
use_speaker_embedding: bool = False
speakers_file: str = None
4 changes: 0 additions & 4 deletions TTS/tts/datasets/__init__.py
@@ -13,7 +13,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.

Args:
<<<<<<< HEAD
items (List[List]):
A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.

@@ -23,9 +22,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
=======
items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
>>>>>>> Fix docstring
"""
speakers = [item["speaker_name"] for item in items]
is_multi_speaker = len(set(speakers)) > 1
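
The cleaned-up docstring above describes the two split parameters. Below is a small hypothetical example of how `split_dataset` might be called; the sample dict layout and the eval-first return order are assumptions, not shown in this diff.

```python
# Hypothetical call to split_dataset; sample contents and return order are assumptions.
from TTS.tts.datasets import split_dataset

samples = [
    {"text": f"sentence {i}", "audio_file": f"wavs/{i:04d}.wav", "speaker_name": "ljspeech"}
    for i in range(1000)
]

# eval_split_size between 0.0 and 1.0 is a fraction of the dataset; values > 1 are an
# absolute number of eval samples. eval_split_max_size presumably caps the eval set size.
eval_samples, train_samples = split_dataset(samples, eval_split_max_size=256, eval_split_size=0.01)
print(len(eval_samples), len(train_samples))  # e.g. 10 and 990 under these assumptions
```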
36 changes: 18 additions & 18 deletions TTS/tts/datasets/dataset.py
@@ -9,7 +9,7 @@
from torch.utils.data import Dataset

from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_f0, load_wav, wav_to_mel, wav_to_spec

# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@@ -37,9 +37,9 @@ def noise_augment_audio(wav):
class TTSDataset(Dataset):
def __init__(
self,
audio_config: "Coqpit" = None,
outputs_per_step: int = 1,
compute_linear_spec: bool = False,
ap: AudioProcessor = None,
samples: List[Dict] = None,
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
@@ -64,12 +64,12 @@ def __init__(
If you need something different, you can subclass and override.

Args:
audio_config (Coqpit): Audio configuration.

outputs_per_step (int): Number of time frames predicted per step.

compute_linear_spec (bool): compute linear spectrogram if True.

ap (TTS.tts.utils.AudioProcessor): Audio processor object.

samples (list): List of dataset samples.

tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
@@ -115,6 +115,7 @@ def __init__(
verbose (bool): Print diagnostic information. Defaults to false.
"""
super().__init__()
self.audio_config = audio_config
self.batch_group_size = batch_group_size
self._samples = samples
self.outputs_per_step = outputs_per_step
@@ -126,7 +127,6 @@ def __init__(
self.max_audio_len = max_audio_len
self.min_text_len = min_text_len
self.max_text_len = max_text_len
self.ap = ap
self.phoneme_cache_path = phoneme_cache_path
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
@@ -146,7 +146,7 @@ def __init__(

if compute_f0:
self.f0_dataset = F0Dataset(
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
self.samples, self.audio_config, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
)

if self.verbose:
@@ -188,7 +188,7 @@ def print_logs(self, level: int = 0) -> None:
print(f"{indent}| > Number of instances : {len(self.samples)}")

def load_wav(self, filename):
waveform = self.ap.load_wav(filename)
waveform = load_wav(filename)
assert waveform.size > 0
return waveform

@@ -408,7 +408,7 @@ def collate_fn(self, batch):
else:
speaker_ids = None
# compute features
mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]
mel = [wav_to_mel(w).astype("float32") for w in batch["wav"]]

mel_lengths = [m.shape[1] for m in mel]

@@ -455,7 +455,7 @@ def collate_fn(self, batch):
# compute linear spectrogram
linear = None
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
linear = [wav_to_spec(w).astype("float32") for w in batch["wav"]]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
@@ -465,13 +465,13 @@
wav_padded = None
if self.return_wav:
wav_lengths = [w.shape[0] for w in batch["wav"]]
max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
max_wav_len = max(mel_lengths_adjusted) * self.audio_config.hop_length
wav_lengths = torch.LongTensor(wav_lengths)
wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
for i, w in enumerate(batch["wav"]):
mel_length = mel_lengths_adjusted[i]
w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.ap.hop_length]
w = np.pad(w, (0, self.audio_config.hop_length * self.outputs_per_step), mode="edge")
w = w[: mel_length * self.audio_config.hop_length]
wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
wav_padded.transpose_(1, 2)

@@ -647,14 +647,14 @@ class F0Dataset:
def __init__(
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
audio_config: "AudioConfig",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
):
self.samples = samples
self.ap = ap
self.audio_config = audio_config
self.verbose = verbose
self.cache_path = cache_path
self.normalize_f0 = normalize_f0
@@ -711,9 +711,9 @@ def create_pitch_file_path(wav_file, cache_path):
return pitch_file

@staticmethod
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
wav = ap.load_wav(wav_file)
pitch = ap.compute_f0(wav)
def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
wav = load_wav(wav_file)
pitch = compute_f0(x=wav, pitch_fmax=audio_config.pitch_fmax, hop_length=audio_config.hop_length, sample_rate=audio_config.sample_rate)
if pitch_file:
np.save(pitch_file, pitch)
return pitch
@@ -750,7 +750,7 @@ def compute_or_load(self, wav_file):
"""
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
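
Taken together, the dataset changes swap the `AudioProcessor` instance for a plain audio config plus the stand-alone `numpy_transforms` functions. A minimal sketch of the new call pattern for pitch computation follows; the config values are illustrative assumptions, while the `load_wav` and `compute_f0` call shapes mirror the ones used in the diff above.

```python
# Sketch of the new numpy_transforms call pattern; config values are illustrative.
from types import SimpleNamespace

from TTS.utils.audio.numpy_transforms import compute_f0, load_wav

# Stand-in for the training config's audio section (assumed values).
audio_config = SimpleNamespace(sample_rate=22050, hop_length=256, pitch_fmax=640.0)

wav = load_wav("wavs/0001.wav")  # hypothetical path
pitch = compute_f0(
    x=wav,
    pitch_fmax=audio_config.pitch_fmax,
    hop_length=audio_config.hop_length,
    sample_rate=audio_config.sample_rate,
)
print(pitch.shape)  # one f0 value per hop
```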