Skip to content

Commit

Permalink
Redownload XTTS with the local and remote config do not match
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Oct 6, 2023
1 parent 0520697 commit 4a6103f
Showing 1 changed file with 47 additions and 21 deletions.
68 changes: 47 additions & 21 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from shutil import copyfile, rmtree
from typing import Dict, List, Tuple

import fsspec
import requests
from tqdm import tqdm

Expand Down Expand Up @@ -320,6 +321,31 @@ def tos_agreed(self, model_item, model_full_path):
return False
return True

def check_if_files_size(self, model_name):
pass

def create_dir_and_download_model(self, model_name, model_item, output_path):
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)

except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)

def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
Expand All @@ -338,28 +364,28 @@ def download_model(self, model_name):
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.")
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f)
remote_url = None
for url in model_item["hf_url"]:
if "config.json" in url:
remote_url = url
break

with fsspec.open(remote_url, "r", encoding="utf-8") as f:
config_remote = json.load(f)

if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
else:
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)

except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
self.create_dir_and_download_model(model_name, model_item, output_path)

# find downloaded files
output_model_path = output_path
output_config_path = None
Expand Down

0 comments on commit 4a6103f

Please sign in to comment.