
Commit

trainer-API updates #1
erogol committed May 20, 2021
1 parent 128cc79 commit 7462bcc
Showing 3 changed files with 22 additions and 32 deletions.
29 changes: 16 additions & 13 deletions TTS/tts/utils/speakers.py
@@ -1,7 +1,7 @@
import json
import os
import random
-from typing import Union
+from typing import Union, List, Any

import numpy as np
import torch
@@ -35,9 +35,7 @@ def save_speaker_mapping(out_path, speaker_mapping):


def get_speakers(items):
"""Returns a sorted, unique list of speakers in a given dataset."""
speakers = {e[2] for e in items}
return sorted(speakers)



def parse_speakers(c, args, meta_data_train, OUT_PATH):
@@ -121,26 +119,31 @@ class SpeakerManager:
Args:
x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
-speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
-TTS model. Defaults to "".
+speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
+TTS models. Defaults to "".
encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
encoder_config_path (str, optional): Path to the speaker encoder config file. Defaults to "".
"""

def __init__(
self,
+data_items: List[List[Any]] = None,
x_vectors_file_path: str = "",
speaker_id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
):

-self.x_vectors = None
-self.speaker_ids = None
-self.clip_ids = None
+self.data_items = []
+self.x_vectors = []
+self.speaker_ids = []
+self.clip_ids = []
self.speaker_encoder = None
self.speaker_encoder_ap = None

+if data_items:
+self.speaker_ids = self.parse_speakers()

if x_vectors_file_path:
self.load_x_vectors_file(x_vectors_file_path)

@@ -169,10 +172,10 @@ def x_vector_dim(self):
return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])

def parser_speakers_from_items(self, items: list):
-speaker_ids = sorted({item[2] for item in items})
-self.speaker_ids = speaker_ids
-num_speakers = len(speaker_ids)
-return speaker_ids, num_speakers
+speakers = sorted({item[2] for item in items})
+self.speaker_ids = {name: i for i, name in enumerate(speakers)}
+num_speakers = len(self.speaker_ids)
+return self.speaker_ids, num_speakers

def save_ids_file(self, file_path: str):
self._save_json(file_path, self.speaker_ids)
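Note on the SpeakerManager changes above: `parser_speakers_from_items` now returns a dictionary mapping each speaker name to an integer id instead of a plain sorted list, and the constructor gains a `data_items` argument. A minimal sketch of the mapping logic, using hypothetical metadata items with the speaker name at index 2 (the layout implied by `item[2]` in the diff):

# Hypothetical metadata items; only index 2 (the speaker name) matters here,
# mirroring the `{item[2] for item in items}` expression above.
items = [
    ["hello world", "wavs/a.wav", "speaker_b"],
    ["good morning", "wavs/b.wav", "speaker_a"],
    ["good night", "wavs/c.wav", "speaker_a"],
]

# Same construction as the new parser_speakers_from_items: sort the unique
# speaker names, then number them consecutively.
speakers = sorted({item[2] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)

print(speaker_ids)   # {'speaker_a': 0, 'speaker_b': 1}
print(num_speakers)  # 2

With the dictionary form, callers can look up a speaker's embedding index by name directly rather than searching a list.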
24 changes: 5 additions & 19 deletions TTS/tts/utils/text/cleaners.py
@@ -1,19 +1,5 @@
"""
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
"""

import re

-from unidecode import unidecode

from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text

from .abbreviations import abbreviations_en, abbreviations_fr
@@ -46,8 +32,8 @@ def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text).strip()


-def convert_to_ascii(text):
-return unidecode(text)
+# def convert_to_ascii(text):
+# return unidecode(text)


def remove_aux_symbols(text):
@@ -77,7 +63,7 @@ def basic_cleaners(text):

def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
-text = convert_to_ascii(text)
+# text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
@@ -101,7 +87,7 @@ def basic_turkish_cleaners(text):

def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
-text = convert_to_ascii(text)
+# text = convert_to_ascii(text)
text = lowercase(text)
text = expand_time_english(text)
text = expand_numbers(text)
@@ -141,7 +127,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
def phoneme_cleaners(text):
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
text = expand_numbers(text)
-text = convert_to_ascii(text)
+# text = convert_to_ascii(text)
text = expand_abbreviations(text)
text = replace_symbols(text)
text = remove_aux_symbols(text)
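The practical effect of commenting out `convert_to_ascii` in the pipelines above is that non-ASCII characters are no longer transliterated by Unidecode and pass through the cleaners unchanged. A self-contained before/after sketch, rebuilt from the fragments shown in this diff (the regex and the old pipeline are approximations, not the library's exact code):

import re

from unidecode import unidecode  # only needed to show the old behaviour

_whitespace_re = re.compile(r"\s+")


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text).strip()


def old_transliteration_cleaners(text):
    # Previous pipeline: transliterate to ASCII, then lowercase and collapse whitespace.
    text = unidecode(text)
    text = text.lower()
    return collapse_whitespace(text)


def new_transliteration_cleaners(text):
    # After this commit: the ASCII step is skipped, so accented characters survive.
    text = text.lower()
    return collapse_whitespace(text)


sample = "Straße  zum   Café"
print(old_transliteration_cleaners(sample))  # strasse zum cafe
print(new_transliteration_cleaners(sample))  # straße zum café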
1 change: 1 addition & 0 deletions tests/vocoder_tests/test_melgan_train.py
@@ -20,6 +20,7 @@
eval_split_size=1,
print_step=1,
print_eval=True,
+discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
data_path="tests/data/ljspeech",
output_path=output_path,
)
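The added `discriminator_model_params` entry shrinks the MelGAN discriminator so the test trains quickly. As a rough illustration of what such parameters conventionally control in a MelGAN-style discriminator (an assumption about the convention, not a reading of the library code): channels start at `base_channels`, grow with each strided layer, and are capped at `max_channels`, while the product of `downsample_factors` gives the total temporal downsampling:

import math

# The parameters added to the test config above.
params = {"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]}

# Hypothetical channel schedule: multiply by each stride and clip at max_channels.
channels = [params["base_channels"]]
for factor in params["downsample_factors"]:
    channels.append(min(channels[-1] * factor, params["max_channels"]))

total_downsampling = math.prod(params["downsample_factors"])

print(channels)            # [16, 64, 256, 256]
print(total_downsampling)  # 64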
