Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancing get_all_bundles_list under monai.bundle to support model zoo NGC hosting #6997

Merged
merged 12 commits into from
Sep 18, 2023
63 changes: 40 additions & 23 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,16 @@ def load(
return model


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.2", replaced="1.5")
def _get_all_bundles_info(
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
) -> dict[str, dict[str, dict[str, Any]]]:
if has_requests:
request_url = f"https://api.github.com/repos/{repo}/releases"
if tag == "hosting_storage_v1":
request_url = f"https://api.github.com/repos/{repo}/releases"
else:
request_url = f"https://raw.githubusercontent.com/{repo}/{tag}/models/model_info.json"

if auth_token is not None:
headers = {"Authorization": f"Bearer {auth_token}"}
resp = requests_get(request_url, headers=headers)
Expand All @@ -558,33 +563,40 @@ def _get_all_bundles_info(
bundle_name_pattern = re.compile(r"_v\d*.")
bundles_info: dict[str, dict[str, dict[str, Any]]] = {}

for release in releases_list:
if release["tag_name"] == tag:
for asset in release["assets"]:
asset_name = bundle_name_pattern.split(asset["name"])[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
bundles_info[asset_name][asset_version] = {
"id": asset["id"],
"name": asset["name"],
"size": asset["size"],
"download_count": asset["download_count"],
"browser_download_url": asset["browser_download_url"],
"created_at": asset["created_at"],
"updated_at": asset["updated_at"],
}
return bundles_info
if tag == "hosting_storage_v1":
for release in releases_list:
if release["tag_name"] == tag:
for asset in release["assets"]:
asset_name = bundle_name_pattern.split(asset["name"])[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
bundles_info[asset_name][asset_version] = dict(asset)
return bundles_info
else:
for asset in releases_list.keys():
asset_name = bundle_name_pattern.split(asset)[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset.split(f"{asset_name}_v")[-1]
bundles_info[asset_name][asset_version] = {
"name": asset,
"browser_download_url": releases_list[asset]["source"],
}
return bundles_info
ericspod marked this conversation as resolved.
Show resolved Hide resolved
return bundles_info


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
def get_all_bundles_list(
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
) -> list[tuple[str, str]]:
"""
Get all bundle names (and the latest versions) that are stored in the release of specified repository
with the provided tag. The default values of arguments correspond to the release of MONAI model zoo.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
with the provided tag. If tag is "dev", will get model information from
https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.
The default values of arguments correspond to the release of MONAI model zoo. In order to increase the
rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting

Expand All @@ -610,6 +622,7 @@ def get_all_bundles_list(
return bundles_list


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
def get_bundle_versions(
bundle_name: str,
repo: str = "Project-MONAI/model-zoo",
Expand All @@ -618,7 +631,8 @@ def get_bundle_versions(
) -> dict[str, list[str] | str]:
"""
Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified
repository with the provided tag.
repository with the provided tag. If tag is "dev", will get model information from
https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand Down Expand Up @@ -646,6 +660,7 @@ def get_bundle_versions(
return {"latest_version": all_versions[-1], "all_versions": all_versions}


@deprecated_arg_default("tag", "hosting_storage_v1", "dev", since="1.3", replaced="1.5")
def get_bundle_info(
bundle_name: str,
version: str | None = None,
Expand All @@ -656,7 +671,9 @@ def get_bundle_info(
"""
Get all information
(including "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle
with the specified bundle name and version.
with the specified bundle name and version which is stored in the release of specified repository with the provided tag.
Since v1.5, "hosting_storage_v1" will be deprecated in favor of "dev", which contains only the "name" and "browser_download_url"
information about a bundle.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand Down Expand Up @@ -685,7 +702,7 @@ def get_bundle_info(
if version not in bundle_info:
raise ValueError(f"version: {version} of bundle: {bundle_name} is not existing.")

return bundle_info[version]
return bundle_info[version] # type: ignore[no-any-return]


@deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.")
Expand Down
32 changes: 27 additions & 5 deletions tests/test_bundle_get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,34 @@

TEST_CASE_2 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None}]

TEST_CASE_FAKE_TOKEN = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]
TEST_CASE_3 = [{"tag": "hosting_storage_v1"}]

TEST_CASE_4 = [{"tag": "dev"}]

TEST_CASE_5 = [{"bundle_name": "brats_mri_segmentation", "tag": "dev"}]

TEST_CASE_6 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None, "tag": "dev"}]

TEST_CASE_FAKE_TOKEN_1 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]

TEST_CASE_FAKE_TOKEN_2 = [
{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken", "tag": "dev"}
]


@skip_if_windows
@SkipIfNoModule("requests")
class TestGetBundleData(unittest.TestCase):
@parameterized.expand([TEST_CASE_3, TEST_CASE_4])
@skip_if_quick
def test_get_all_bundles_list(self):
def test_get_all_bundles_list(self, params):
with skip_if_downloading_fails():
output = get_all_bundles_list()
output = get_all_bundles_list(**params)
self.assertTrue(isinstance(output, list))
self.assertTrue(isinstance(output[0], tuple))
self.assertTrue(len(output[0]) == 2)

@parameterized.expand([TEST_CASE_1])
@parameterized.expand([TEST_CASE_1, TEST_CASE_5])
@skip_if_quick
def test_get_bundle_versions(self, params):
with skip_if_downloading_fails():
Expand All @@ -57,7 +70,16 @@ def test_get_bundle_info(self, params):
for key in ["id", "name", "size", "download_count", "browser_download_url"]:
self.assertTrue(key in output)

@parameterized.expand([TEST_CASE_FAKE_TOKEN])
@parameterized.expand([TEST_CASE_5, TEST_CASE_6])
@skip_if_quick
def test_get_bundle_info_monaihosting(self, params):
with skip_if_downloading_fails():
output = get_bundle_info(**params)
self.assertTrue(isinstance(output, dict))
for key in ["name", "browser_download_url"]:
self.assertTrue(key in output)

@parameterized.expand([TEST_CASE_FAKE_TOKEN_1, TEST_CASE_FAKE_TOKEN_2])
@skip_if_quick
def test_fake_token(self, params):
with skip_if_downloading_fails():
Expand Down