Skip to content

Commit

Permalink
trainer-API updates #1
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed May 27, 2021
1 parent 0d58360 commit 8baa9e1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
29 changes: 16 additions & 13 deletions TTS/tts/utils/speakers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import random
from typing import Union
from typing import Union, List, Any

import numpy as np
import torch
Expand Down Expand Up @@ -35,9 +35,7 @@ def save_speaker_mapping(out_path, speaker_mapping):


def get_speakers(items):
    """Return a sorted list of the unique speakers in a given dataset.

    Args:
        items: Iterable of dataset entries. The value stored at index 2 of
            each entry is used as that entry's speaker identifier.

    Returns:
        list: The distinct speaker identifiers, in ascending sorted order.
    """
    # Collect into a set first so duplicates are dropped before sorting.
    speakers = {entry[2] for entry in items}
    return sorted(speakers)



def parse_speakers(c, args, meta_data_train, OUT_PATH):
Expand Down Expand Up @@ -121,26 +119,31 @@ class SpeakerManager:
Args:
x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
TTS model. Defaults to "".
speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
TTS models. Defaults to "".
encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
encoder_config_path (str, optional): Path to the speaker encoder config file. Defaults to "".
"""

def __init__(
self,
data_items: List[List[Any]] = None,
x_vectors_file_path: str = "",
speaker_id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
):

self.x_vectors = None
self.speaker_ids = None
self.clip_ids = None
self.data_items = []
self.x_vectors = []
self.speaker_ids = []
self.clip_ids = []
self.speaker_encoder = None
self.speaker_encoder_ap = None

if data_items:
self.speaker_ids = self.parse_speakers()

if x_vectors_file_path:
self.load_x_vectors_file(x_vectors_file_path)

Expand Down Expand Up @@ -169,10 +172,10 @@ def x_vector_dim(self):
return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])

def parser_speakers_from_items(self, items: list):
speaker_ids = sorted({item[2] for item in items})
self.speaker_ids = speaker_ids
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
speakers = sorted({item[2] for item in items})
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(self.speaker_ids)
return self.speaker_ids, num_speakers

def save_ids_file(self, file_path: str):
self._save_json(file_path, self.speaker_ids)
Expand Down
6 changes: 3 additions & 3 deletions TTS/tts/utils/text/cleaners.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def basic_cleaners(text):

def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
# text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
Expand All @@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):

def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
# text = convert_to_ascii(text)
text = lowercase(text)
text = expand_time_english(text)
text = expand_numbers(text)
Expand Down Expand Up @@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
def phoneme_cleaners(text):
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
text = expand_numbers(text)
text = convert_to_ascii(text)
# text = convert_to_ascii(text)
text = expand_abbreviations(text)
text = replace_symbols(text)
text = remove_aux_symbols(text)
Expand Down
1 change: 1 addition & 0 deletions tests/vocoder_tests/test_melgan_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
print_step=1,
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
print_eval=True,
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
data_path="tests/data/ljspeech",
output_path=output_path,
)
Expand Down

0 comments on commit 8baa9e1

Please sign in to comment.