Skip to content

Commit

Permalink
Enhancing get_all_bundles_list under monai.bundle to support model …
Browse files Browse the repository at this point in the history
…zoo NGC hosting (#6997)

Fixes #6833 

### Description
Add `model_info_url` in `get_all_bundles_list`, `get_bundle_versions`,
and `get_bundle_info`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
KumoLiu authored Sep 18, 2023
1 parent b31367f commit 73a7601
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 28 deletions.
62 changes: 39 additions & 23 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,16 @@ def load(
return model


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.2", replaced="1.5")
def _get_all_bundles_info(
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
) -> dict[str, dict[str, dict[str, Any]]]:
if has_requests:
request_url = f"https://api.github.com/repos/{repo}/releases"
if tag == "hosting_storage_v1":
request_url = f"https://api.github.com/repos/{repo}/releases"
else:
request_url = f"https://raw.githubusercontent.com/{repo}/{tag}/models/model_info.json"

if auth_token is not None:
headers = {"Authorization": f"Bearer {auth_token}"}
resp = requests_get(request_url, headers=headers)
Expand All @@ -558,33 +563,39 @@ def _get_all_bundles_info(
bundle_name_pattern = re.compile(r"_v\d*.")
bundles_info: dict[str, dict[str, dict[str, Any]]] = {}

for release in releases_list:
if release["tag_name"] == tag:
for asset in release["assets"]:
asset_name = bundle_name_pattern.split(asset["name"])[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
bundles_info[asset_name][asset_version] = {
"id": asset["id"],
"name": asset["name"],
"size": asset["size"],
"download_count": asset["download_count"],
"browser_download_url": asset["browser_download_url"],
"created_at": asset["created_at"],
"updated_at": asset["updated_at"],
}
return bundles_info
if tag == "hosting_storage_v1":
for release in releases_list:
if release["tag_name"] == tag:
for asset in release["assets"]:
asset_name = bundle_name_pattern.split(asset["name"])[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
bundles_info[asset_name][asset_version] = dict(asset)
return bundles_info
else:
for asset in releases_list.keys():
asset_name = bundle_name_pattern.split(asset)[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset.split(f"{asset_name}_v")[-1]
bundles_info[asset_name][asset_version] = {
"name": asset,
"browser_download_url": releases_list[asset]["source"],
}
return bundles_info


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
def get_all_bundles_list(
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
) -> list[tuple[str, str]]:
"""
Get all bundles names (and the latest versions) that are stored in the release of specified repository
with the provided tag. The default values of arguments correspond to the release of MONAI model zoo.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
with the provided tag. If tag is "dev", the model information is fetched from
https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.
The default values of arguments correspond to the release of MONAI model zoo. In order to increase the
rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand All @@ -610,6 +621,7 @@ def get_all_bundles_list(
return bundles_list


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
def get_bundle_versions(
bundle_name: str,
repo: str = "Project-MONAI/model-zoo",
Expand All @@ -618,7 +630,8 @@ def get_bundle_versions(
) -> dict[str, list[str] | str]:
"""
Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified
repository with the provided tag.
repository with the provided tag. If tag is "dev", the model information is fetched from
https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand Down Expand Up @@ -646,6 +659,7 @@ def get_bundle_versions(
return {"latest_version": all_versions[-1], "all_versions": all_versions}


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
def get_bundle_info(
bundle_name: str,
version: str | None = None,
Expand All @@ -656,7 +670,9 @@ def get_bundle_info(
"""
Get all information
(include "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle
with the specified bundle name and version.
with the specified bundle name and version which is stored in the release of specified repository with the provided tag.
Since v1.5, "hosting_storage_v1" will be deprecated in favor of 'dev', which contains only the "name" and
"browser_download_url" information about a bundle.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand Down Expand Up @@ -685,7 +701,7 @@ def get_bundle_info(
if version not in bundle_info:
raise ValueError(f"version: {version} of bundle: {bundle_name} is not existing.")

return bundle_info[version]
return bundle_info[version] # type: ignore[no-any-return]


@deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.")
Expand Down
32 changes: 27 additions & 5 deletions tests/test_bundle_get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,34 @@

TEST_CASE_2 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None}]

TEST_CASE_FAKE_TOKEN = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]
TEST_CASE_3 = [{"tag": "hosting_storage_v1"}]

TEST_CASE_4 = [{"tag": "dev"}]

TEST_CASE_5 = [{"bundle_name": "brats_mri_segmentation", "tag": "dev"}]

TEST_CASE_6 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None, "tag": "dev"}]

TEST_CASE_FAKE_TOKEN_1 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]

TEST_CASE_FAKE_TOKEN_2 = [
{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken", "tag": "dev"}
]


@skip_if_windows
@SkipIfNoModule("requests")
class TestGetBundleData(unittest.TestCase):
@parameterized.expand([TEST_CASE_3, TEST_CASE_4])
@skip_if_quick
def test_get_all_bundles_list(self):
def test_get_all_bundles_list(self, params):
with skip_if_downloading_fails():
output = get_all_bundles_list()
output = get_all_bundles_list(**params)
self.assertTrue(isinstance(output, list))
self.assertTrue(isinstance(output[0], tuple))
self.assertTrue(len(output[0]) == 2)

@parameterized.expand([TEST_CASE_1])
@parameterized.expand([TEST_CASE_1, TEST_CASE_5])
@skip_if_quick
def test_get_bundle_versions(self, params):
with skip_if_downloading_fails():
Expand All @@ -57,7 +70,16 @@ def test_get_bundle_info(self, params):
for key in ["id", "name", "size", "download_count", "browser_download_url"]:
self.assertTrue(key in output)

@parameterized.expand([TEST_CASE_FAKE_TOKEN])
@parameterized.expand([TEST_CASE_5, TEST_CASE_6])
@skip_if_quick
def test_get_bundle_info_monaihosting(self, params):
with skip_if_downloading_fails():
output = get_bundle_info(**params)
self.assertTrue(isinstance(output, dict))
for key in ["name", "browser_download_url"]:
self.assertTrue(key in output)

@parameterized.expand([TEST_CASE_FAKE_TOKEN_1, TEST_CASE_FAKE_TOKEN_2])
@skip_if_quick
def test_fake_token(self, params):
with skip_if_downloading_fails():
Expand Down

0 comments on commit 73a7601

Please sign in to comment.