Skip to content

Commit

Permalink
Add download from huggingface_hub functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
katielink committed May 1, 2023
1 parent b7d462d commit be3e678
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Mapping, Sequence
from pathlib import Path
from pydoc import locate
from shutil import copyfile
from shutil import copyfile, copytree, rmtree
from textwrap import dedent
from typing import Any, Callable

Expand Down Expand Up @@ -193,6 +193,15 @@ def _download_from_ngc(
extractall(filepath=filepath, output_dir=extract_path, has_base=True)


def _download_from_huggingface_hub(repo: str, download_path: str, filename: str) -> None:
    """
    Download all files of a Hugging Face Hub repository into ``download_path/filename``.

    The snapshot is fetched into a throwaway cache directory created inside
    ``download_path`` and then copied to the final bundle folder, so that no hub
    cache artifacts are left behind next to the bundle once the call returns.

    Args:
        repo: Hub repository id, in the form ``repo_owner/repo_name``.
        download_path: directory under which the bundle folder is created.
            It is created if it does not exist yet.
        filename: name of the bundle folder inside ``download_path``. Files already
            present in that folder are kept unless overwritten by the download.

    Raises:
        ValueError: when ``repo`` is not exactly two non-empty parts separated by ``/``.

    """
    import tempfile  # local import: only needed by this helper

    # Validate before touching the filesystem or the network.
    repo_parts = repo.split("/")
    if len(repo_parts) != 2 or not all(part.strip() for part in repo_parts):
        raise ValueError("if source is `huggingface_hub`, repo should be in the form `repo_owner/repo_name`")

    download_dir = os.path.join(download_path, filename)
    os.makedirs(download_path, exist_ok=True)
    # NOTE(review): the hub cache is expected to keep blobs/refs beside the returned snapshot
    # folder (see the huggingface_hub cache documentation); removing only the snapshot folder
    # would leave those behind, hence the whole temporary cache directory is discarded instead.
    with tempfile.TemporaryDirectory(dir=download_path) as cache_dir:
        snapshot_folder = huggingface_hub.snapshot_download(repo_id=repo, cache_dir=cache_dir)
        # copytree's default `symlinks=False` copies the content of linked files, not the links,
        # so the bundle folder stays valid after the temporary cache is deleted.
        copytree(snapshot_folder, download_dir, dirs_exist_ok=True)


def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None:
if source == "ngc":
name = _add_ngc_prefix(name)
Expand Down Expand Up @@ -248,6 +257,9 @@ def download(
# Execute this module as a CLI entry, and download bundle from ngc with latest version:
python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"
# Execute this module as a CLI entry, and download bundle from Hugging Face Hub:
python -m monai.bundle download --name "bundle_name" --source "huggingface_hub" --repo "repo_owner/repo_name"
# Execute this module as a CLI entry, and download bundle via URL:
python -m monai.bundle download --name <bundle_name> --url <url>
Expand All @@ -271,9 +283,10 @@ def download(
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
it should be "ngc", "github", or "huggingface_hub".
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
Expand Down Expand Up @@ -333,9 +346,17 @@ def download(
remove_prefix=remove_prefix_,
progress=progress_,
)
elif source_ == "huggingface_hub":
if name_ is None:
raise ValueError(f"To download from source: 'huggingface_hub', `name` must be provided, got {name_}.")
_download_from_huggingface_hub(
repo=repo_,
download_path=bundle_dir_,
filename=name_
)
else:
raise NotImplementedError(
f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: {source_}."
f"Currently only download from `url`, source 'github', 'ngc', or 'huggingface_hub' are implemented, got source: {source_}."
)


Expand Down

0 comments on commit be3e678

Please sign in to comment.