Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
KumoLiu committed Sep 18, 2023
1 parent 281cb01 commit 546d6c0
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 30 deletions.
113 changes: 85 additions & 28 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,24 @@ def load(
return model


@deprecated_arg_default(
"model_info_url",
None,
"https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json",
since="1.3",
replaced="1.5",
)
def _get_all_bundles_info(
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
repo: str = "Project-MONAI/model-zoo",
tag: str = "hosting_storage_v1",
auth_token: str | None = None,
model_info_url: str | None = None,
) -> dict[str, dict[str, dict[str, Any]]]:
if has_requests:
request_url = f"https://api.github.com/repos/{repo}/releases"
if model_info_url is not None:
request_url = model_info_url
else:
request_url = f"https://api.github.com/repos/{repo}/releases"
if auth_token is not None:
headers = {"Authorization": f"Bearer {auth_token}"}
resp = requests_get(request_url, headers=headers)
Expand All @@ -558,33 +571,56 @@ def _get_all_bundles_info(
bundle_name_pattern = re.compile(r"_v\d*.")
bundles_info: dict[str, dict[str, dict[str, Any]]] = {}

for release in releases_list:
if release["tag_name"] == tag:
for asset in release["assets"]:
asset_name = bundle_name_pattern.split(asset["name"])[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
bundles_info[asset_name][asset_version] = {
"id": asset["id"],
"name": asset["name"],
"size": asset["size"],
"download_count": asset["download_count"],
"browser_download_url": asset["browser_download_url"],
"created_at": asset["created_at"],
"updated_at": asset["updated_at"],
}
return bundles_info
if model_info_url is not None:
for asset in releases_list.keys():
asset_name = bundle_name_pattern.split(asset)[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset.split(f"{asset_name}_v")[-1]
bundles_info[asset_name][asset_version] = {
"name": asset,
"browser_download_url": releases_list[asset]["source"],
}
return bundles_info
else:
for release in releases_list:
if release["tag_name"] == tag:
for asset in release["assets"]:
asset_name = bundle_name_pattern.split(asset["name"])[0]
if asset_name not in bundles_info:
bundles_info[asset_name] = {}
asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
bundles_info[asset_name][asset_version] = {
"id": asset["id"],
"name": asset["name"],
"size": asset["size"],
"download_count": asset["download_count"],
"browser_download_url": asset["browser_download_url"],
"created_at": asset["created_at"],
"updated_at": asset["updated_at"],
}
return bundles_info
return bundles_info


@deprecated_arg_default(
"model_info_url",
None,
"https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json",
since="1.3",
replaced="1.5",
)
def get_all_bundles_list(
repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None
repo: str = "Project-MONAI/model-zoo",
tag: str = "hosting_storage_v1",
auth_token: str | None = None,
model_info_url: str | None = None,
) -> list[tuple[str, str]]:
"""
Get all bundles names (and the latest versions) that are stored in the release of specified repository
with the provided tag. The default values of arguments correspond to the release of MONAI model zoo.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
with the provided tag or listed in `model_info_url`. The default values of arguments correspond to the
release of MONAI model zoo. In order to increase the rate limits of calling Github APIs, you can input
your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand All @@ -595,13 +631,14 @@ def get_all_bundles_list(
repo: it should be in the form of "repo_owner/repo_name/".
tag: the tag name of the release.
auth_token: github personal access token.
model_info_url: a JSON file link containing all of the model information.
Returns:
a list of tuple in the form of (bundle name, latest version).
"""

bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token, model_info_url=model_info_url)
bundles_list = []
for bundle_name in bundles_info:
latest_version = sorted(bundles_info[bundle_name].keys())[-1]
Expand All @@ -610,15 +647,23 @@ def get_all_bundles_list(
return bundles_list


@deprecated_arg_default(
"model_info_url",
None,
"https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json",
since="1.3",
replaced="1.5",
)
def get_bundle_versions(
bundle_name: str,
repo: str = "Project-MONAI/model-zoo",
tag: str = "hosting_storage_v1",
auth_token: str | None = None,
model_info_url: str | None = None,
) -> dict[str, list[str] | str]:
"""
Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified
repository with the provided tag.
repository with the provided tag or listed in `model_info_url`.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand All @@ -631,13 +676,14 @@ def get_bundle_versions(
repo: it should be in the form of "repo_owner/repo_name/".
tag: the tag name of the release.
auth_token: github personal access token.
model_info_url: a JSON file link containing all of the model information.
Returns:
a dictionary that contains the latest version and all versions of a bundle.
"""

bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token, model_info_url=model_info_url)
if bundle_name not in bundles_info:
raise ValueError(f"bundle: {bundle_name} is not existing in repo: {repo}.")
bundle_info = bundles_info[bundle_name]
Expand All @@ -646,17 +692,27 @@ def get_bundle_versions(
return {"latest_version": all_versions[-1], "all_versions": all_versions}


@deprecated_arg_default(
"model_info_url",
None,
"https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json",
since="1.3",
replaced="1.5",
)
def get_bundle_info(
bundle_name: str,
version: str | None = None,
repo: str = "Project-MONAI/model-zoo",
tag: str = "hosting_storage_v1",
auth_token: str | None = None,
model_info_url: str | None = None,
) -> dict[str, Any]:
"""
Get all information
(include "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle
with the specified bundle name and version.
with the specified bundle name and version which is stored in the release of specified repository with the provided tag.
Since v1.5, it has been deprecated in favor of'model_info_url', which contains only "name" and "browser_download_url"
information about a bundle.
In order to increase the rate limits of calling Github APIs, you can input your personal access token.
Please check the following link for more details about rate limiting:
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
Expand All @@ -670,13 +726,14 @@ def get_bundle_info(
repo: it should be in the form of "repo_owner/repo_name/".
tag: the tag name of the release.
auth_token: github personal access token.
model_info_url: a JSON file link containing all of the model information.
Returns:
a dictionary that contains the bundle's information.
"""

bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token, model_info_url=model_info_url)
if bundle_name not in bundles_info:
raise ValueError(f"bundle: {bundle_name} is not existing.")
bundle_info = bundles_info[bundle_name]
Expand All @@ -685,7 +742,7 @@ def get_bundle_info(
if version not in bundle_info:
raise ValueError(f"version: {version} of bundle: {bundle_name} is not existing.")

return bundle_info[version]
return bundle_info[version] # type: ignore[no-any-return]


@deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.")
Expand Down
47 changes: 45 additions & 2 deletions tests/test_bundle_get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,16 @@

TEST_CASE_2 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None}]

TEST_CASE_FAKE_TOKEN = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]
TEST_CASE_FAKE_TOKEN_1 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]

TEST_CASE_FAKE_TOKEN_2 = [
{
"bundle_name": "spleen_ct_segmentation",
"version": "0.1.0",
"auth_token": "ghp_errortoken",
"model_info_url": "https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json",
}
]


@skip_if_windows
Expand All @@ -39,6 +48,16 @@ def test_get_all_bundles_list(self):
self.assertTrue(isinstance(output[0], tuple))
self.assertTrue(len(output[0]) == 2)

@skip_if_quick
def test_get_all_bundles_list_model_info_url(self):
with skip_if_downloading_fails():
output = get_all_bundles_list(
model_info_url="https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json"
)
self.assertTrue(isinstance(output, list))
self.assertTrue(isinstance(output[0], tuple))
self.assertTrue(len(output[0]) == 2)

@parameterized.expand([TEST_CASE_1])
@skip_if_quick
def test_get_bundle_versions(self, params):
Expand All @@ -48,6 +67,18 @@ def test_get_bundle_versions(self, params):
self.assertTrue("latest_version" in output and "all_versions" in output)
self.assertTrue("0.1.0" in output["all_versions"])

@parameterized.expand([TEST_CASE_1])
@skip_if_quick
def test_get_bundle_versions_model_info_url(self, params):
with skip_if_downloading_fails():
output = get_bundle_versions(
model_info_url="https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json",
**params,
)
self.assertTrue(isinstance(output, dict))
self.assertTrue("latest_version" in output and "all_versions" in output)
self.assertTrue("0.1.0" in output["all_versions"])

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@skip_if_quick
def test_get_bundle_info(self, params):
Expand All @@ -57,7 +88,19 @@ def test_get_bundle_info(self, params):
for key in ["id", "name", "size", "download_count", "browser_download_url"]:
self.assertTrue(key in output)

@parameterized.expand([TEST_CASE_FAKE_TOKEN])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@skip_if_quick
def test_get_bundle_info_model_info_url(self, params):
with skip_if_downloading_fails():
output = get_bundle_info(
model_info_url="https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json",
**params,
)
self.assertTrue(isinstance(output, dict))
for key in ["name", "browser_download_url"]:
self.assertTrue(key in output)

@parameterized.expand([TEST_CASE_FAKE_TOKEN_1, TEST_CASE_FAKE_TOKEN_2])
@skip_if_quick
def test_fake_token(self, params):
with skip_if_downloading_fails():
Expand Down

0 comments on commit 546d6c0

Please sign in to comment.