diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 7cabc0cc2b..9379532465 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -19,7 +19,7 @@ from collections.abc import Mapping, Sequence from pathlib import Path from pydoc import locate -from shutil import copyfile +from shutil import copyfile, copytree, rmtree from textwrap import dedent from typing import Any, Callable @@ -193,6 +193,15 @@ def _download_from_ngc( extractall(filepath=filepath, output_dir=extract_path, has_base=True) +def _download_from_huggingface_hub(repo: str, download_path: str, filename: str) -> None: + if len(repo.split("/")) != 2: + raise ValueError("if source is `hf_hub`, repo should be in the form `repo_owner/repo_name`") + snapshot_folder = huggingface_hub.snapshot_download(repo_id=repo, cache_dir=download_path) + download_dir = os.path.join(download_path, filename) + copytree(snapshot_folder, download_dir, dirs_exist_ok=True) + rmtree(snapshot_folder) + + def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None: if source == "ngc": name = _add_ngc_prefix(name) @@ -248,6 +257,9 @@ def download( # Execute this module as a CLI entry, and download bundle from ngc with latest version: python -m monai.bundle download --name --source "ngc" --bundle_dir "./" + # Execute this module as a CLI entry, and download bundle from Hugging Face Hub: + python -m monai.bundle download --name "bundle_name" --source "huggingface_hub" --repo "repo_owner/repo_name" + # Execute this module as a CLI entry, and download bundle via URL: python -m monai.bundle download --name --url @@ -271,9 +283,10 @@ def download( Default is `bundle` subfolder under `torch.hub.get_dir()`. source: storage location name. This argument is used when `url` is `None`. In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and - it should be "ngc" or "github". - repo: repo name. This argument is used when `url` is `None` and `source` is "github". - If used, it should be in the form of "repo_owner/repo_name/release_tag". + it should be "ngc", "github", or "huggingface_hub". + repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub". + If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag". + If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name". url: url to download the data. If not `None`, data will be downloaded directly and `source` will not be checked. If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. @@ -333,9 +346,17 @@ def download( remove_prefix=remove_prefix_, progress=progress_, ) + elif source_ == "huggingface_hub": + if name_ is None: + raise ValueError(f"To download from source: 'huggingface_hub', `name` must be provided, got {name_}.") + _download_from_huggingface_hub( + repo=repo_, + download_path=bundle_dir_, + filename=name_ + ) else: raise NotImplementedError( - f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: {source_}." + f"Currently only download from `url`, source 'github', 'ngc', or 'huggingface_hub' are implemented, got source: {source_}." )