From 77b175986a063b82f0147eb311579003b4ed3569 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:59:04 +0800 Subject: [PATCH] Using keys in .npz files for `NumpyReader` (#7148) Fixes #7147. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu --- monai/data/image_reader.py | 2 +- tests/test_numpy_reader.py | 36 +++++++++++++++++++++++------------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index ce77c549ce..4c7f2c8c3b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1068,7 +1068,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img = np.load(name, allow_pickle=True, **kwargs_) if Path(name).name.endswith(".npz"): # load expected items from NPZ file - npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys + npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys for k in npz_keys: img_.append(img[k]) else: diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index 393f613163..eeff2922ad 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -19,7 +19,7 @@ import numpy as np from monai.data import DataLoader, Dataset, NumpyReader -from monai.transforms import LoadImaged +from monai.transforms import LoadImage, LoadImaged from tests.utils import assert_allclose @@ -97,22 +97,32 @@ def test_kwargs(self): def test_dataloader(self): test_data = np.random.randint(0, 256, size=[3, 4, 5]) - datalist = [] + datalist_dict, datalist_array = [], [] with tempfile.TemporaryDirectory() as tempdir: for i in range(4): filepath = os.path.join(tempdir, f"test_data{i}.npz") np.savez(filepath, test_data) - datalist.append({"image": filepath}) - - num_workers = 2 if sys.platform == "linux" else 0 - loader = DataLoader( - Dataset(data=datalist, transform=LoadImaged(keys="image", reader=NumpyReader())), - batch_size=2, - num_workers=num_workers, - ) - for d in loader: - for c in d["image"]: - assert_allclose(c, test_data, type_test=False) + datalist_dict.append({"image": filepath}) + datalist_array.append(filepath) + + num_workers = 2 if sys.platform == "linux" else 0 + loader = DataLoader( + Dataset(data=datalist_dict, transform=LoadImaged(keys="image", reader=NumpyReader())), + batch_size=2, + num_workers=num_workers, + ) + for d in loader: + for c in d["image"]: + assert_allclose(c, test_data, type_test=False) + + loader = DataLoader( + Dataset(data=datalist_array, transform=LoadImage(reader=NumpyReader())), + batch_size=2, + num_workers=num_workers, + ) + for d in loader: + for c in d: + assert_allclose(c, test_data, type_test=False) def test_channel_dim(self): test_data = np.random.randint(0, 256, size=[3, 4, 5, 2])