Skip to content

Commit

Permalink
Fix the VITS upsampling asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed May 7, 2022
1 parent 3f03e30 commit b51e54c
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
hann_window = {}
mel_basis = {}


def load_audio(file_path):
"""Load the audio file normalized in [-1, 1]
Expand Down Expand Up @@ -189,15 +188,20 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm


class VitsDataset(TTSDataset):
def __init__(self, *args, **kwargs):
def __init__(self, model_args, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pad_id = self.tokenizer.characters.pad_id
self.model_args = model_args

def __getitem__(self, idx):
item = self.samples[idx]
raw_text = item["text"]

wav, _ = load_audio(item["audio_file"])
if self.model_args.encoder_sample_rate is not None:
if wav.size(1) % self.model_args.encoder_sample_rate != 0:
wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)]

wav_filename = os.path.basename(item["audio_file"])

token_ids = self.get_token_ids(idx, item["text"])
Expand Down Expand Up @@ -1401,8 +1405,11 @@ def format_batch_on_device(self, batch):
if self.args.encoder_sample_rate:
# recompute spec with high sampling rate to the loss
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
# remove extra stft frame
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
# remove extra stft frames if needed
if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor):
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
else:
batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)]
else:
spec_mel = batch["spec"]

Expand Down Expand Up @@ -1451,6 +1458,7 @@ def get_data_loader(
else:
# init dataloader
dataset = VitsDataset(
model_args=self.args,
samples=samples,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_text_len=config.min_text_len,
Expand Down

0 comments on commit b51e54c

Please sign in to comment.