4042 Hugging Face Hub integration #6454

Merged
merged 34 commits into from
Oct 19, 2023
Changes from 31 commits
Commits
b7d462d
Add huggingface_hub as an optional dependency
katielink May 1, 2023
be3e678
Add download from huggingface_hub functionality
katielink May 1, 2023
bd61de8
Add huggingface_hub as an optional dependency
katielink May 1, 2023
d23aacd
Add download from huggingface_hub functionality
katielink May 1, 2023
fd11128
Merge branch '4042-hf-hub-integration' of https://github.com/katielin…
katielink May 2, 2023
165d659
Refactored downloading bundle
katielink May 2, 2023
e11e6a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2023
44b09ac
Add initial push functionality
katielink Jun 21, 2023
cbbece8
Merge branch '4042-hf-hub-integration' of https://github.com/katielin…
katielink Jun 21, 2023
de05033
Fix docstring
katielink Jun 21, 2023
c6071c2
Merge remote-tracking branch 'upstream/dev' into 4042-hf-hub-integration
katielink Jun 21, 2023
12fc6c2
Style + naming updates
katielink Jul 17, 2023
393e0d0
Add exception for repo format for hf hub
katielink Jul 17, 2023
96ef680
Refactor integration for better repo/bundle versioning
katielink Aug 15, 2023
697fbf1
Remove create_branch in push_to_hub flow for simplicity
katielink Aug 15, 2023
77ba1e1
Update docstring for push_to_huggingface_hub
katielink Aug 15, 2023
d210e41
Formatting and minor updates
katielink Aug 16, 2023
9950573
Add push_to_hf_hub usage example and function to docs
katielink Aug 16, 2023
f80734c
Formatting
katielink Aug 16, 2023
9698039
Fix bug
katielink Aug 16, 2023
c3e1956
Merge branch 'dev' into 4042-hf-hub-integration
katielink Aug 16, 2023
aac4679
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2023
bd53123
Merge branch 'dev' into 4042-hf-hub-integration
katielink Aug 16, 2023
eff08bc
Add push to hub test
katielink Oct 18, 2023
9008e29
Merge branch '4042-hf-hub-integration' of https://github.com/katielin…
katielink Oct 18, 2023
d92911d
Merge branch 'dev' into 4042-hf-hub-integration
katielink Oct 18, 2023
724269c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2023
1570cb9
Fix lint errors
katielink Oct 18, 2023
6a848be
Add test to exclude_cases
katielink Oct 18, 2023
654b1d8
Fix mypy error
katielink Oct 19, 2023
492bb59
Fix docs error
katielink Oct 19, 2023
6646ac3
Update requirements-dev.txt
katielink Oct 19, 2023
e63cfda
Update setup.cfg
katielink Oct 19, 2023
b6fb1ec
Merge branch 'dev' into 4042-hf-hub-integration
wyli Oct 19, 2023
1 change: 1 addition & 0 deletions docs/source/bundle.rst
@@ -48,4 +48,5 @@ Model Bundle
 .. autofunction:: verify_metadata
 .. autofunction:: verify_net_in_out
 .. autofunction:: init_bundle
+.. autofunction:: push_to_hf_hub
 .. autofunction:: update_kwargs
5 changes: 2 additions & 3 deletions docs/source/installation.md
@@ -254,11 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
 - The options are

 ```
-[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml]
+[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]
 ```

 which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips` and `nvidia-ml-py` respectively.
-
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, and `huggingface_hub` respectively.

 - `pip install 'monai[all]'` installs all the optional dependencies.
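Because `huggingface_hub` is an optional extra, calling code may want to check that it is importable before using the new source. A minimal sketch using only the standard library; the helper name `has_optional_package` is hypothetical and not part of MONAI:

```python
import importlib.util


def has_optional_package(module_name: str) -> bool:
    """Return True if a top-level module named `module_name` can be found."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    if has_optional_package("huggingface_hub"):
        print("huggingface_hub is available")
    else:
        # The extra name comes from the setup.cfg change in this PR.
        print("missing; try: pip install 'monai[huggingface_hub]'")
```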
1 change: 1 addition & 0 deletions monai/bundle/__init__.py
@@ -26,6 +26,7 @@
     init_bundle,
     load,
     onnx_export,
+    push_to_hf_hub,
     run,
     run_workflow,
     trt_export,
124 changes: 113 additions & 11 deletions monai/bundle/scripts.py
@@ -57,6 +57,7 @@
 Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
 requests_get, has_requests = optional_import("requests", name="get")
 onnx, _ = optional_import("onnx")
+huggingface_hub, _ = optional_import("huggingface_hub")

 logger = get_logger(module_name=__name__)

@@ -244,6 +245,14 @@ def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, l
     elif source == "github":
         repo_owner, repo_name, tag_name = repo.split("/")
         return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
+    elif source == "huggingface_hub":
+        refs = huggingface_hub.list_repo_refs(repo_id=repo)
+        if len(refs.tags) > 0:
+            all_versions = [t.name for t in refs.tags]  # git tags, not to be confused with `tag`
+            latest_version = ["latest_version" if "latest_version" in all_versions else all_versions[-1]][0]
+        else:
+            latest_version = [b.name for b in refs.branches][0]  # use the branch that was last updated
+        return latest_version
     else:
         raise ValueError(
             f"To get the latest bundle version, source should be 'github', 'monaihosting' or 'ngc', got {source}."
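The Hugging Face branch in this hunk picks a default revision from the repo's tags and branches. The selection rule can be sketched on plain lists of names, with no network call. `pick_default_revision` is a hypothetical helper for illustration, not a MONAI function, and the explicit error for an empty repo is an addition that the diff does not contain:

```python
from __future__ import annotations


def pick_default_revision(tag_names: list[str], branch_names: list[str]) -> str:
    """Mirror the rule in the diff: prefer the `latest_version` tag,
    else the last tag listed, else the first branch listed."""
    if tag_names:
        return "latest_version" if "latest_version" in tag_names else tag_names[-1]
    if not branch_names:
        # Assumption: fail loudly instead of raising IndexError on an empty list.
        raise ValueError("repo has neither tags nor branches")
    return branch_names[0]
```

Whether "last tag listed" and "first branch listed" correspond to the newest refs depends on the order in which the Hub returns them, which this sketch does not try to establish.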
@@ -293,6 +302,9 @@ def download(
             # Execute this module as a CLI entry, and download bundle from monaihosting with latest version:
             python -m monai.bundle download --name <bundle_name> --source "monaihosting" --bundle_dir "./"

+            # Execute this module as a CLI entry, and download bundle from Hugging Face Hub:
+            python -m monai.bundle download --name "bundle_name" --source "huggingface_hub" --repo "repo_owner/repo_name"
+
             # Execute this module as a CLI entry, and download bundle via URL:
             python -m monai.bundle download --name <bundle_name> --url <url>
@@ -311,14 +323,15 @@ def download(
             "monai_brats_mri_segmentation" in ngc:
             https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
         version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
-            the latest version.
+            the latest version (or the last commit to the `main` branch in the case of Hugging Face Hub).
         bundle_dir: target directory to store the downloaded data.
             Default is `bundle` subfolder under `torch.hub.get_dir()`.
         source: storage location name. This argument is used when `url` is `None`.
             In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
-            it should be "ngc", "monaihosting" or "github".
-        repo: repo name. This argument is used when `url` is `None` and `source` is "github".
-            If used, it should be in the form of "repo_owner/repo_name/release_tag".
+            it should be "ngc", "monaihosting", "github", or "huggingface_hub".
+        repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
+            If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
+            If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
         url: url to download the data. If not `None`, data will be downloaded directly
             and `source` will not be checked.
             If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
@@ -351,9 +364,10 @@ def download(
     bundle_dir_ = _process_bundle_dir(bundle_dir_)
     if repo_ is None:
         repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
-    if len(repo_.split("/")) != 3:
+    if len(repo_.split("/")) != 3 and source_ != "huggingface_hub":
         raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")
-
+    elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
+        raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`")
     if url_ is not None:
         if name_ is not None:
             filepath = bundle_dir_ / f"{name_}.zip"
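The repo-format check in this hunk can be read as one rule: two `/`-separated parts for the Hugging Face Hub, three for every other source. A minimal standalone sketch of that rule follows; `validate_repo` is a hypothetical name, not a function in MONAI. One consequence visible in the hunk is that the default repo value has three parts, so with `source="huggingface_hub"` the caller has to pass `repo` explicitly.

```python
def validate_repo(repo: str, source: str) -> None:
    """Raise ValueError if `repo` does not have the number of parts `source` expects."""
    n_parts = len(repo.split("/"))
    if source == "huggingface_hub":
        if n_parts != 2:
            raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`.")
    elif n_parts != 3:
        raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")
```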
@@ -380,9 +394,12 @@ def download(
                 remove_prefix=remove_prefix_,
                 progress=progress_,
             )
+        elif source_ == "huggingface_hub":
+            extract_path = os.path.join(bundle_dir_, name_)
+            huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
         else:
             raise NotImplementedError(
-                "Currently only download from `url`, source 'github', 'monaihosting' or 'ngc' are implemented,"
+                "Currently only download from `url`, source 'github', 'monaihosting', 'huggingface_hub' or 'ngc' are implemented,"
                 f"got source: {source_}."
             )
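From Python, the new source would be selected through the same `download` arguments as the CLI example in the docstring. The sketch below keeps the argument assembly separate so that it can be inspected without MONAI installed; both helper names are hypothetical, and the actual call needs `monai`, `huggingface_hub`, and network access:

```python
from __future__ import annotations

from typing import Any


def hf_download_kwargs(
    name: str, repo: str, bundle_dir: str = "./", version: str | None = None
) -> dict[str, Any]:
    """Build keyword arguments for `monai.bundle.download` targeting the Hugging Face Hub."""
    return {
        "name": name,
        "version": version,  # None lets `download` resolve a default revision
        "bundle_dir": bundle_dir,
        "source": "huggingface_hub",
        "repo": repo,  # expected form: "repo_owner/repo_name"
    }


def download_from_hf_hub(**kwargs: Any) -> None:
    """Run the download; imported lazily so the helper above works without MONAI."""
    from monai.bundle import download

    download(**hf_download_kwargs(**kwargs))
```

Per the hunk, the files end up in `os.path.join(bundle_dir, name)`.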

@@ -427,7 +444,7 @@ def load(
             https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/mednist_gan/versions/0.2.0/files/mednist_gan_v0.2.0.zip
         model: a pytorch module to be updated. Default to None, using the "network_def" in the bundle.
         version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
-            the latest version.
+            the latest version. If `source` is "huggingface_hub", this argument is a Git revision id.
         workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
@@ -440,9 +457,10 @@ def load(
         source: storage location name. This argument is used when `model_file` is not existing locally and need to be
             downloaded first.
             In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
-            it should be "ngc", "monaihosting" or "github".
-        repo: repo name. This argument is used when `url` is `None` and `source` is "github".
-            If used, it should be in the form of "repo_owner/repo_name/release_tag".
+            it should be "ngc", "monaihosting", "github", or "huggingface_hub".
+        repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
+            If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
+            If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
         remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
             have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
             maintain the consistency between these three sources, remove prefix is necessary.
@@ -1597,6 +1615,90 @@ def init_bundle(
         save_state(network, str(models_dir / "model.pt"))


+def _add_model_card_metadata(new_modelcard_path):
+    # Extract license from LICENSE file
+    license_name = "unknown"
+    license_path = os.path.join(os.path.dirname(new_modelcard_path), "LICENSE")
+    if os.path.exists(license_path):
+        with open(license_path) as file:
+            content = file.read()
+        if "Apache License" in content and "Version 2.0" in content:
+            license_name = "apache-2.0"
+        elif "MIT License" in content:
+            license_name = "mit"
+    # Add relevant tags
+    tags = "- monai\n- medical\nlibrary_name: monai\n"
+    # Create tag section
+    tag_content = f"---\ntags:\n{tags}license: {license_name}\n---"
+
+    # Update model card
+    with open(new_modelcard_path) as file:
+        content = file.read()
+    new_content = tag_content + "\n" + content
+    with open(new_modelcard_path, "w") as file:
+        file.write(new_content)
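For reference, the YAML front matter that `_add_model_card_metadata` prepends can be reproduced with two small pure functions. This is a sketch that mirrors the string handling in the diff; the names `detect_license` and `build_front_matter` are hypothetical and do not exist in MONAI:

```python
def detect_license(license_text: str) -> str:
    """Map the text of a LICENSE file to a Hub license id, as the diff does."""
    if "Apache License" in license_text and "Version 2.0" in license_text:
        return "apache-2.0"
    if "MIT License" in license_text:
        return "mit"
    return "unknown"


def build_front_matter(license_name: str) -> str:
    """Return the front-matter block placed at the top of README.md."""
    tags = "- monai\n- medical\nlibrary_name: monai\n"
    return f"---\ntags:\n{tags}license: {license_name}\n---"
```

Only Apache 2.0 and MIT are recognized; any other license text falls through to `unknown`.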


+def push_to_hf_hub(
+    repo: str,
+    name: str,
+    bundle_dir: str,
+    token: str | None = None,
+    private: bool | None = True,
+    version: str | None = None,
+    tag_as_latest_version: bool | None = False,
+    **upload_folder_kwargs: Any,
+) -> Any:
+    """
+    Push a MONAI bundle to the Hugging Face Hub.
+
+    Typical usage examples:
+
+    .. code-block:: bash
+
+        python -m monai.bundle push_to_hf_hub --repo <HF repository id> --name <bundle name> \
+            --bundle_dir <bundle directory> --version <version> ...
+
+    Args:
+        repo: namespace (user or organization) and a repo name separated by a /, e.g. `hf_username/bundle_name`
+        name: name of the bundle directory to push.
+        bundle_dir: path to the bundle directory.
+        token: Hugging Face authentication token. Default is `None` (will default to the stored token).
+        private: Private visibility of the repository on Hugging Face. Default is `True`.
+        version: Name of the version tag to create. Default is `None` (no version tag is created).
+        tag_as_latest_version: Whether to tag the commit as `latest_version`.
+            This version will be downloaded by default when using `bundle.download()`. Default is `False`.
+        upload_folder_kwargs: Keyword arguments to pass to `HfApi.upload_folder`.
+
+    Returns:
+        repo_url: URL of the Hugging Face repo
+    """
+    # Connect to API and create repo
+    hf_api = huggingface_hub.HfApi(token=token)
+    hf_api.create_repo(repo_id=repo, private=private, exist_ok=True)
+
+    # Create model card in bundle directory
+    new_modelcard_path = os.path.join(bundle_dir, name, "README.md")
+    modelcard_path = os.path.join(bundle_dir, name, "docs", "README.md")
+    if os.path.exists(modelcard_path):
+        # Copy README from old path if it exists
+        copyfile(modelcard_path, new_modelcard_path)
+        _add_model_card_metadata(new_modelcard_path)
+
+    # Upload bundle folder to repo
+    repo_url = hf_api.upload_folder(repo_id=repo, folder_path=os.path.join(bundle_dir, name), **upload_folder_kwargs)
+
+    # Create version tag if specified
+    if version is not None:
+        hf_api.create_tag(repo_id=repo, tag=version, exist_ok=True)
+
+    # Optionally tag as `latest_version`
+    if tag_as_latest_version:
+        hf_api.create_tag(repo_id=repo, tag="latest_version", exist_ok=True)
+
+    return repo_url
+
+
 def create_workflow(
     workflow_name: str | BundleWorkflow | None = None,
     config_file: str | Sequence[str] | None = None,
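A Python call to `push_to_hf_hub` would take the same arguments as the CLI form in its docstring. The sketch below also spells out which git tags the function creates for a given combination of `version` and `tag_as_latest_version`. The helper names `tags_to_create` and `push_bundle_example` are hypothetical, as are the repo id, bundle name, and directory; the push itself needs `monai`, `huggingface_hub`, a stored or explicit token, and network access, so it is wrapped in a function that is not called here:

```python
from __future__ import annotations

from typing import Any


def tags_to_create(version: str | None, tag_as_latest_version: bool = False) -> list[str]:
    """List the tags `push_to_hf_hub` would create after uploading the folder."""
    tags = []
    if version is not None:
        tags.append(version)
    if tag_as_latest_version:
        tags.append("latest_version")
    return tags


def push_bundle_example() -> Any:
    """Push ./bundles/bundle_name to a private repo and tag it (placeholder values)."""
    from monai.bundle import push_to_hf_hub  # lazy import; requires MONAI with this PR

    return push_to_hf_hub(
        repo="hf_username/bundle_name",
        name="bundle_name",
        bundle_dir="./bundles",
        version="0.1.0",
        tag_as_latest_version=True,
    )
```

Tagging with `latest_version` matters on the download side: when a repo has tags, that tag is preferred as the default revision.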
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -53,6 +53,7 @@ onnx>=1.13.0
 onnxruntime; python_version <= '3.10'
 typeguard<3 # https://github.com/microsoft/nni/issues/5457
 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
+huggingface_hub
 zarr
 lpips==0.1.4
 nvidia-ml-py
3 changes: 3 additions & 0 deletions setup.cfg
@@ -80,6 +80,7 @@ all =
     optuna
     onnx>=1.13.0
     onnxruntime; python_version <= '3.10'
+    huggingface_hub
     zarr
     lpips==0.1.4
     nvidia-ml-py
@@ -160,6 +161,8 @@ pynvml =
 # # workaround https://github.com/Project-MONAI/MONAI/issues/5882
 # MetricsReloaded =
 # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
+huggingface_hub =
+    huggingface_hub

 [flake8]
 select = B,C,E,F,N,P,T4,W,B9
1 change: 1 addition & 0 deletions tests/min_tests.py
@@ -37,6 +37,7 @@ def run_testsuit():
         "test_auto3dseg",
         "test_bundle_onnx_export",
         "test_bundle_trt_export",
+        "test_bundle_push_to_hf_hub",
         "test_cachedataset",
         "test_cachedataset_parallel",
         "test_cachedataset_persistent_workers",