From 8baa9e119ab549e5504f9b42d2f4c74e6e61f4e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Thu, 20 May 2021 18:23:53 +0200
Subject: [PATCH] trainer-API updates #1

---
 TTS/tts/utils/speakers.py                | 29 +++++++++++++-----------
 TTS/tts/utils/text/cleaners.py           |  6 ++---
 tests/vocoder_tests/test_melgan_train.py |  1 +
 3 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 84da1f72ee..4ab78f8836 100755
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union
+from typing import Union, List, Any
 
 import numpy as np
 import torch
@@ -35,9 +35,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
 
 
 def get_speakers(items):
-    """Returns a sorted, unique list of speakers in a given dataset."""
-    speakers = {e[2] for e in items}
-    return sorted(speakers)
+    return sorted({e[2] for e in items})
 
 
 def parse_speakers(c, args, meta_data_train, OUT_PATH):
@@ -121,26 +119,31 @@ class SpeakerManager:
     Args:
         x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
-        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
-            TTS model. Defaults to "".
+        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
+            TTS models. Defaults to "".
         encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
         encoder_config_path (str, optional): Path to the speaker encoder config file. Defaults to "".
     """
 
     def __init__(
         self,
+        data_items: List[List[Any]] = None,
         x_vectors_file_path: str = "",
         speaker_id_file_path: str = "",
         encoder_model_path: str = "",
         encoder_config_path: str = "",
     ):
-        self.x_vectors = None
-        self.speaker_ids = None
-        self.clip_ids = None
+        self.data_items = []
+        self.x_vectors = []
+        self.speaker_ids = []
+        self.clip_ids = []
         self.speaker_encoder = None
         self.speaker_encoder_ap = None
 
+        if data_items:
+            self.speaker_ids, _ = self.parser_speakers_from_items(data_items)
+
         if x_vectors_file_path:
             self.load_x_vectors_file(x_vectors_file_path)
@@ -169,10 +172,10 @@ def x_vector_dim(self):
         return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])
 
     def parser_speakers_from_items(self, items: list):
-        speaker_ids = sorted({item[2] for item in items})
-        self.speaker_ids = speaker_ids
-        num_speakers = len(speaker_ids)
-        return speaker_ids, num_speakers
+        speakers = sorted({item[2] for item in items})
+        self.speaker_ids = {name: i for i, name in enumerate(speakers)}
+        num_speakers = len(self.speaker_ids)
+        return self.speaker_ids, num_speakers
 
     def save_ids_file(self, file_path: str):
         self._save_json(file_path, self.speaker_ids)
diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py
index 3d2caa9764..4b041ed845 100644
--- a/TTS/tts/utils/text/cleaners.py
+++ b/TTS/tts/utils/text/cleaners.py
@@ -65,7 +65,7 @@ def basic_cleaners(text):
 
 def transliteration_cleaners(text):
     """Pipeline for non-English text that transliterates to ASCII."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = collapse_whitespace(text)
     return text
@@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):
 
 def english_cleaners(text):
     """Pipeline for English text, including number and abbreviation expansion."""
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = lowercase(text)
     text = expand_time_english(text)
     text = expand_numbers(text)
@@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
 def phoneme_cleaners(text):
     """Pipeline for phonemes mode, including number and abbreviation expansion."""
     text = expand_numbers(text)
-    text = convert_to_ascii(text)
+    # text = convert_to_ascii(text)
     text = expand_abbreviations(text)
     text = replace_symbols(text)
     text = remove_aux_symbols(text)
diff --git a/tests/vocoder_tests/test_melgan_train.py b/tests/vocoder_tests/test_melgan_train.py
index 3ff65b5af4..e3004db7e7 100644
--- a/tests/vocoder_tests/test_melgan_train.py
+++ b/tests/vocoder_tests/test_melgan_train.py
@@ -21,5 +21,6 @@
     print_step=1,
     print_eval=True,
+    discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
     data_path="tests/data/ljspeech",
     output_path=output_path,
 )
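
Usage sketch (illustrative, not part of the patch): a minimal example of how the reworked SpeakerManager could be fed dataset items after this change. The sample sentences, wav paths, and speaker names below are made up; the only assumptions taken from the diff are that each item stores the speaker name at index 2 and that parser_speakers_from_items builds a speaker-name-to-id mapping.

    from TTS.tts.utils.speakers import SpeakerManager

    # Hypothetical dataset items; the patch only relies on the speaker name living at item[2].
    data_items = [
        ["hello there", "wavs/spk_a_001.wav", "speaker_a"],
        ["good morning", "wavs/spk_b_001.wav", "speaker_b"],
        ["good night", "wavs/spk_a_002.wav", "speaker_a"],
    ]

    # With the new constructor argument, speaker ids are parsed from the items directly.
    manager = SpeakerManager(data_items=data_items)
    print(manager.speaker_ids)       # expected: {'speaker_a': 0, 'speaker_b': 1}
    print(len(manager.speaker_ids))  # expected: 2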
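
The cleaners change only disables the ASCII transliteration step in transliteration_cleaners, english_cleaners, and phoneme_cleaners, so accented characters now pass through those pipelines unchanged. A rough sketch of the observable difference (the input text is made up; the exact output still depends on the remaining cleaner steps):

    from TTS.tts.utils.text.cleaners import english_cleaners

    # With convert_to_ascii commented out, the "ö" survives instead of being flattened
    # to an ASCII "o"; lowercasing, number and abbreviation expansion still apply.
    print(english_cleaners("Dr. Gölge reads 2 books."))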