diff --git a/parallel_wavegan/utils/utils.py b/parallel_wavegan/utils/utils.py index 43cfab83..0e0ba992 100644 --- a/parallel_wavegan/utils/utils.py +++ b/parallel_wavegan/utils/utils.py @@ -8,6 +8,7 @@ import fnmatch import logging import os +import re import sys import tarfile @@ -359,22 +360,48 @@ def load_model(checkpoint, config=None, stats=None): return model -def download_pretrained_model(tag, download_dir=None): +def download_pretrained_model(tag_or_url, download_dir=None): """Download pretrained model form google drive. Args: - tag (str): Pretrained model tag. + tag_or_url (str): Pretrained model tag or the google drive url for the model. download_dir (str): Directory to save downloaded files. Returns: str: Path of downloaded model checkpoint. + Examples: + # Download by specifying tag + >>> from parallel_wavegan.utils import download_pretrained_model + >>> tag = "ljspeech_parallel_wavegan.v1" + >>> download_path = download_pretrained_model(tag) + + # Download by specifying URL + >>> from parallel_wavegan.utils import download_pretrained_model + >>> url = "https://drive.google.com/file/d/10GYvB_mIKzXzSjD67tSnBhknZRoBjsNb" + >>> download_path = download_pretrained_model(url) + + # The following URL also works + >>> url = "https://drive.google.com/file/d/10GYvB_mIKzXzSjD67tSnBhknZRoBjsNb/view?usp=sharing" + >>> download_path = download_pretrained_model(url) + """ - assert tag in PRETRAINED_MODEL_LIST, f"{tag} does not exists." - id_ = PRETRAINED_MODEL_LIST[tag] if download_dir is None: download_dir = os.path.expanduser("~/.cache/parallel_wavegan") - output_path = f"{download_dir}/{tag}.tar.gz" + if tag_or_url in PRETRAINED_MODEL_LIST: + id_ = PRETRAINED_MODEL_LIST[tag_or_url] + output_path = f"{download_dir}/{tag_or_url}.tar.gz" + tag = tag_or_url + else: + # get google drive id from the url link + assert ( + "drive.google.com" in tag_or_url + ), "Unknown URL format. Please use google drive for the model." + p = re.compile(r"/[-\w]{25,}") + id_ = p.findall(tag_or_url)[0][1:] + tag = id_ + output_path = f"{download_dir}/{id_}.tar.gz" + os.makedirs(f"{download_dir}", exist_ok=True) with FileLock(output_path + ".lock"): if not os.path.exists(output_path):