Skip to content

Commit

Permalink
Fix RandomWeightedCrop for Integer Weightmap Handling (#8097)
Browse files Browse the repository at this point in the history
Fixes #7949 .

### Description
Regardless of the type of `weight map`, random numbers should be kept as
floating-point numbers for calculating the sampling location. However,
`searchsorted` requires matching data structures. I have modified
`convert_to_dst_type` to control converting only the data structure
while maintaining the original data type. Additionally, I have included
an example with integer weight maps in the test file.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Han123su <popsmall212@gmail.com>
Signed-off-by: Han123su <107395380+Han123su@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Han123su and pre-commit-ci[bot] authored Sep 20, 2024
1 parent d2d492e commit fa1c1af
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
3 changes: 2 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,8 @@ def weighted_patch_samples(
if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling
idx = r_state.randint(0, len(v), size=n_samples)
else:
r, *_ = convert_to_dst_type(r_state.random(n_samples), v)
r_samples = r_state.random(n_samples)
r, *_ = convert_to_dst_type(r_samples, v, dtype=r_samples.dtype)
idx = searchsorted(v, r * v[-1], right=True) # type: ignore
idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore
# compensate 'valid' mode
Expand Down
30 changes: 30 additions & 0 deletions tests/test_rand_weighted_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ def get_data(ndim):
[[63, 37], [31, 43], [66, 20]],
]
)
im = SEG1_2D
weight_map = np.zeros_like(im, dtype=np.int32)
weight_map[0, 30, 20] = 3
weight_map[0, 45, 44] = 1
weight_map[0, 60, 50] = 2
TESTS.append(
[
"int w 2d",
dict(spatial_size=(10, 12), num_samples=3),
p(im),
q(weight_map),
(1, 10, 12),
[[60, 50], [30, 20], [45, 44]],
]
)
im = SEG1_3D
weight = np.zeros_like(im)
weight[0, 5, 30, 17] = 1.1
Expand Down Expand Up @@ -149,6 +164,21 @@ def get_data(ndim):
[[32, 24, 40], [32, 24, 40], [32, 24, 40]],
]
)
im = SEG1_3D
weight_map = np.zeros_like(im, dtype=np.int32)
weight_map[0, 6, 22, 19] = 4
weight_map[0, 8, 40, 31] = 2
weight_map[0, 13, 20, 24] = 3
TESTS.append(
[
"int w 3d",
dict(spatial_size=(8, 10, 12), num_samples=3),
p(im),
q(weight_map),
(1, 8, 10, 12),
[[13, 20, 24], [6, 22, 19], [8, 40, 31]],
]
)


class TestRandWeightedCrop(CropTest):
Expand Down

0 comments on commit fa1c1af

Please sign in to comment.