Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4042 Hugging Face Hub integration #6454

Merged
merged 34 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b7d462d
Add huggingface_hub as an optional dependency
katielink May 1, 2023
be3e678
Add download from huggingface_hub functionality
katielink May 1, 2023
bd61de8
Add huggingface_hub as an optional dependency
katielink May 1, 2023
d23aacd
Add download from huggingface_hub functionality
katielink May 1, 2023
fd11128
Merge branch '4042-hf-hub-integration' of https://github.com/katielin…
katielink May 2, 2023
165d659
Refactored downloading bundle
katielink May 2, 2023
e11e6a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
44b09ac
Add initial push functionality
katielink Jun 21, 2023
cbbece8
Merge branch '4042-hf-hub-integration' of https://github.com/katielin…
katielink Jun 21, 2023
de05033
Fix docstring
katielink Jun 21, 2023
c6071c2
Merge remote-tracking branch 'upstream/dev' into 4042-hf-hub-integration
katielink Jun 21, 2023
12fc6c2
Style + naming updates
katielink Jul 17, 2023
393e0d0
Add exception for repo format for hf hub
katielink Jul 17, 2023
96ef680
Refactor integration for better repo/bundle versioning
katielink Aug 15, 2023
697fbf1
Remove create_branch in push_to_hub flow for simplicity
katielink Aug 15, 2023
77ba1e1
Update docstring for push_to_huggingface_hub
katielink Aug 15, 2023
d210e41
Formatting and minor updates
katielink Aug 16, 2023
9950573
Add push_to_hf_hub usage example and function to docs
katielink Aug 16, 2023
f80734c
Formatting
katielink Aug 16, 2023
9698039
Fix bug
katielink Aug 16, 2023
c3e1956
Merge branch 'dev' into 4042-hf-hub-integration
katielink Aug 16, 2023
aac4679
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2023
bd53123
Merge branch 'dev' into 4042-hf-hub-integration
katielink Aug 16, 2023
eff08bc
Add push to hub test
katielink Oct 18, 2023
9008e29
Merge branch '4042-hf-hub-integration' of https://github.com/katielin…
katielink Oct 18, 2023
d92911d
Merge branch 'dev' into 4042-hf-hub-integration
katielink Oct 18, 2023
724269c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2023
1570cb9
Fix lint errors
katielink Oct 18, 2023
6a848be
Add test to exclude_cases
katielink Oct 18, 2023
654b1d8
Fix mypy error
katielink Oct 19, 2023
492bb59
Fix docs error
katielink Oct 19, 2023
6646ac3
Update requirements-dev.txt
katielink Oct 19, 2023
e63cfda
Update setup.cfg
katielink Oct 19, 2023
b6fb1ec
Merge branch 'dev' into 4042-hf-hub-integration
wyli Oct 19, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ optuna
opencv-python-headless
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
huggingface_hub
8 changes: 4 additions & 4 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
- [Uninstall the packages](#uninstall-the-packages)
- [From conda-forge](#from-conda-forge)
- [From GitHub](#from-github)
- [Option 1 (as a part of your system-wide module)](#option-1-as-a-part-of-your-system-wide-module)
- [Option 2 (editable installation)](#option-2-editable-installation)
- [Option 1 (as a part of your system-wide module):](#option-1-as-a-part-of-your-system-wide-module)
- [Option 2 (editable installation):](#option-2-editable-installation)
- [Validating the install](#validating-the-install)
- [MONAI version string](#monai-version-string)
- [From DockerHub](#from-dockerhub)
Expand Down Expand Up @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, huggingface_hub]
```

which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `huggingface_hub`, respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
57 changes: 48 additions & 9 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
requests_get, has_requests = optional_import("requests", name="get")
onnx, _ = optional_import("onnx")
huggingface_hub, _ = optional_import("huggingface_hub")

logger = get_logger(module_name=__name__)

Expand Down Expand Up @@ -192,6 +193,13 @@ def _download_from_ngc(
extractall(filepath=filepath, output_dir=extract_path, has_base=True)


def _download_from_huggingface_hub(repo: str, download_path: str, filename: str, version: str) -> None:
wyli marked this conversation as resolved.
Show resolved Hide resolved
if len(repo.split("/")) != 2:
wyli marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("if source is `huggingface_hub`, repo should be in the form `repo_owner/repo_name`")
extract_path = os.path.join(download_path, filename)
huggingface_hub.snapshot_download(repo_id=repo, revision=version, local_dir=extract_path, local_dir_use_symlinks="auto")
wyli marked this conversation as resolved.
Show resolved Hide resolved


def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None:
if source == "ngc":
name = _add_ngc_prefix(name)
Expand All @@ -203,6 +211,10 @@ def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, l
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
elif source == "huggingface_hub":
huggingface_hub.list_repo_refs(repo_id=f"{repo}/{name}", repo_type="model")
#TODO: implement this
return None
else:
raise ValueError(f"To get the latest bundle version, source should be 'github' or 'ngc', got {source}.")

Expand Down Expand Up @@ -247,6 +259,9 @@ def download(
# Execute this module as a CLI entry, and download bundle from ngc with latest version:
python -m monai.bundle download --name <bundle_name> --source "ngc" --bundle_dir "./"

# Execute this module as a CLI entry, and download bundle from Hugging Face Hub:
python -m monai.bundle download --name "bundle_name" --source "huggingface_hub" --repo "repo_owner/repo_name"

# Execute this module as a CLI entry, and download bundle via URL:
python -m monai.bundle download --name <bundle_name> --url <url>

Expand All @@ -265,14 +280,15 @@ def download(
"monai_brats_mri_segmentation" in ngc:
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
the latest version.
the latest version. If `source` is "huggingface_hub", this argument is a Git revision id.
bundle_dir: target directory to store the downloaded data.
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
it should be "ngc", "github", or "huggingface_hub".
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
Expand Down Expand Up @@ -332,9 +348,18 @@ def download(
remove_prefix=remove_prefix_,
progress=progress_,
)
elif source_ == "huggingface_hub":
if name_ is None:
raise ValueError(f"To download from source: 'huggingface_hub', `name` must be provided, got {name_}.")
_download_from_huggingface_hub(
repo=repo_,
download_path=bundle_dir_,
filename=name_,
version=version_,
)
else:
raise NotImplementedError(
f"Currently only download from `url`, source 'github' or 'ngc' are implemented, got source: {source_}."
f"Currently only download from `url`, source 'github', 'ngc', or 'huggingface_hub' are implemented, got source: {source_}."
)


Expand Down Expand Up @@ -365,7 +390,7 @@ def load(
"monai_brats_mri_segmentation" in ngc:
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
the latest version.
the latest version. If `source` is "huggingface_hub", this argument is a Git revision id.
model_file: the relative path of the model weights or TorchScript module within bundle.
If `None`, "models/model.pt" or "models/model.ts" will be used.
load_ts_module: a flag to specify if loading the TorchScript module.
Expand All @@ -374,9 +399,10 @@ def load(
source: storage location name. This argument is used when `model_file` is not existing locally and need to be
downloaded first.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
it should be "ngc" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
it should be "ngc", "github", or "huggingface_hub".
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
maintain the consistency between these two sources, remove prefix is necessary.
Expand Down Expand Up @@ -1501,3 +1527,16 @@ def init_bundle(
copyfile(str(ckpt_file), str(models_dir / "model.pt"))
elif network is not None:
save_state(network, str(models_dir / "model.pt"))


def push_to_hf_hub(bundle_dir: str, repo_name: str) -> None:
"""
Push the current bundle to the Hugging Face Hub.

Args:
bundle_dir: path to the bundle directory to push
repo_name: name of the repo to create or push to the HF Hub
"""
hf_api = huggingface_hub.HfApi()
repo_id = hf_api.create_repo(name=repo_name, exist_ok=True)
wyli marked this conversation as resolved.
Show resolved Hide resolved
return hf_api.upload_folder(path=bundle_dir, repo_id=repo_id)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ onnx>=1.13.0
onnxruntime; python_version <= '3.10'
typeguard<3 # https://github.com/microsoft/nni/issues/5457
filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
huggingface_hub
katielink marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ all =
optuna
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
huggingface_hub
katielink marked this conversation as resolved.
Show resolved Hide resolved
nibabel =
nibabel
ninja =
Expand Down Expand Up @@ -145,6 +146,8 @@ onnx =
# # workaround https://github.com/Project-MONAI/MONAI/issues/5882
# MetricsReloaded =
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
huggingface_hub =
huggingface_hub

[flake8]
select = B,C,E,F,N,P,T4,W,B9
Expand Down