Skip to content

Commit

Permalink
Add support for model_info in CLI (#1623)
Browse files Browse the repository at this point in the history
* model_info

* model_info

* model_info_by_idx and name

* model_info_by_idx and name

* model_info

* Update manage.py

* fixed linter

* fixed linter

* fixed linter

* fixed linter

* fixed return style checks

* fixed linter

* fixed linter

* fixed idx always positive

* added comments

* fix parser.args check

* fix parser.args check

* Make style

Co-authored-by: Eren G??lge <egolge@coqui.ai>
  • Loading branch information
p0p4k and erogol authored Jun 20, 2022
1 parent 8b75e8b commit 71281ff
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 11 deletions.
63 changes: 52 additions & 11 deletions TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ def main():
$ tts --list_models
```
- Query info for model info by idx:
```
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
```
- Query info for model info by full name:
```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```
- Run TTS with default models:
```
Expand All @@ -48,7 +60,7 @@ def main():
- Run a TTS model with its default vocoder model:
```
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
```
- Run with specific TTS and vocoder models from the list:
Expand Down Expand Up @@ -104,6 +116,21 @@ def main():
default=False,
help="list available pre-trained TTS and vocoder models.",
)

parser.add_argument(
"--model_info_by_idx",
type=str,
default=None,
help="model info using query format: <model_type>/<model_query_idx>",
)

parser.add_argument(
"--model_info_by_name",
type=str,
default=None,
help="model info using query format: <model_type>/<language>/<dataset>/<model_name>",
)

parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")

# Args for running pre-trained TTS models.
Expand Down Expand Up @@ -214,13 +241,16 @@ def main():
args = parser.parse_args()

# print the description if either text or list_models is not set
if (
not args.text
and not args.list_models
and not args.list_speaker_idxs
and not args.list_language_idxs
and not args.reference_wav
):
check_args = [
args.text,
args.list_models,
args.list_speaker_idxs,
args.list_language_idxs,
args.reference_wav,
args.model_info_by_idx,
args.model_info_by_name,
]
if not any(check_args):
parser.parse_args(["-h"])

# load model manager
Expand All @@ -236,20 +266,31 @@ def main():
encoder_path = None
encoder_config_path = None

# CASE1: list pre-trained TTS models
# CASE1 #list : list pre-trained TTS models
if args.list_models:
manager.list_models()
sys.exit()

# CASE2: load pre-trained model paths
# CASE2 #info : model info of pre-trained TTS models
if args.model_info_by_idx:
model_query = args.model_info_by_idx
manager.model_info_by_idx(model_query)
sys.exit()

if args.model_info_by_name:
model_query_full_name = args.model_info_by_name
manager.model_info_by_full_name(model_query_full_name)
sys.exit()

# CASE3: load pre-trained model paths
if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name)
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name

if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

# CASE3: set custom model paths
# CASE4: set custom model paths
if args.model_path is not None:
model_path = args.model_path
config_path = args.config_path
Expand Down
75 changes: 75 additions & 0 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,81 @@ def list_models(self):
models_name_list.extend(model_list)
return models_name_list

def model_info_by_idx(self, model_query):
"""Print the description of the model from .models.json file using model_idx
Args:
model_query (str): <model_tye>/<model_idx>
"""
model_name_list = []
model_type, model_query_idx = model_query.split("/")
try:
model_query_idx = int(model_query_idx)
if model_query_idx <= 0:
print("> model_query_idx should be a positive integer!")
return
except:
print("> model_query_idx should be an integer!")
return
model_count = 0
if model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
else:
print(f"> model_type {model_type} does not exist in the list.")
return
if model_query_idx > model_count:
print(f"model query idx exceeds the number of available models [{model_count}] ")
else:
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")

def model_info_by_full_name(self, model_query_name):
"""Print the description of the model from .models.json file using model_full_name
Args:
model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
"""
model_type, lang, dataset, model = model_query_name.split("/")
if model_type in self.models_dict:
if lang in self.models_dict[model_type]:
if dataset in self.models_dict[model_type][lang]:
if model in self.models_dict[model_type][lang][dataset]:
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
)
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
)
else:
print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
else:
print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
else:
print(f"> lang {lang} does not exist for {model_type}.")
else:
print(f"> model_type {model_type} does not exist in the list.")

def list_tts_models(self):
"""Print all `TTS` models and return a list of model names
Expand Down

0 comments on commit 71281ff

Please sign in to comment.