Skip to content

Commit

Permalink
5770 add remove prefix to monai.bundle.load (#5771)
Browse files Browse the repository at this point in the history
Signed-off-by: Yiheng Wang <vennw@nvidia.com>

Fixes #5770 .

### Description

This PR enables the `remove_prefix` arg in `monai.bundle.load`

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
  • Loading branch information
yiheng-wang-nv authored Dec 19, 2022
1 parent 8037fcd commit b159ce7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
19 changes: 18 additions & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def load(
bundle_dir: Optional[PathLike] = None,
source: str = download_source,
repo: Optional[str] = None,
remove_prefix: Optional[str] = "monai_",
progress: bool = True,
device: Optional[str] = None,
key_in_ckpt: Optional[str] = None,
Expand Down Expand Up @@ -356,6 +357,10 @@ def load(
it should be "ngc" or "github".
repo: repo name. This argument is used when `url` is `None` and `source` is "github".
If used, it should be in the form of "repo_owner/repo_name/release_tag".
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
maintain the consistency between these two sources, remove prefix is necessary.
Therefore, if specified, downloaded folder name will remove the prefix.
progress: whether to display a progress bar when downloading.
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
Expand All @@ -379,9 +384,21 @@ def load(

if model_file is None:
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
if source == "ngc":
name = _add_ngc_prefix(name)
if remove_prefix:
name = _remove_ngc_prefix(name, prefix=remove_prefix)
full_path = os.path.join(bundle_dir_, name, model_file)
if not os.path.exists(full_path):
download(name=name, version=version, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)
download(
name=name,
version=version,
bundle_dir=bundle_dir_,
source=source,
repo=repo,
remove_prefix=remove_prefix,
progress=progress,
)

if device is None:
device = "cuda:0" if is_available() else "cpu"
Expand Down
35 changes: 33 additions & 2 deletions tests/ngc_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

from monai.apps import check_hash
from monai.apps.mmars import MODEL_DESC, load_from_mmar
from monai.bundle import download
from monai.bundle import download, load
from monai.config import print_debug_info
from monai.networks.utils import copy_model_state
from tests.utils import skip_if_downloading_fails, skip_if_quick, skip_if_windows
from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, skip_if_windows

TEST_CASE_NGC_1 = [
"spleen_ct_segmentation",
Expand All @@ -41,6 +41,30 @@
"b418a2dc8672ce2fd98dc255036e7a3d",
]

TESTCASE_WEIGHTS = {
"key": "model.0.conv.unit0.adn.N.bias",
"value": torch.tensor(
[
-0.0705,
-0.0937,
-0.0422,
-0.2068,
0.1023,
-0.2007,
-0.0883,
0.0018,
-0.1719,
0.0116,
0.0285,
-0.0044,
0.1223,
-0.1287,
-0.1858,
0.0460,
]
),
}


@skip_if_windows
class TestNgcBundleDownload(unittest.TestCase):
Expand All @@ -56,6 +80,13 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download
self.assertTrue(os.path.exists(full_file_path))
self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))

weights = load(
name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix
)
assert_allclose(
weights[TESTCASE_WEIGHTS["key"]], TESTCASE_WEIGHTS["value"], atol=1e-4, rtol=1e-4, type_test=False
)


@unittest.skip("deprecating mmar tests")
class TestAllDownloadingMMAR(unittest.TestCase):
Expand Down

0 comments on commit b159ce7

Please sign in to comment.