Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Jul 29, 2023
1 parent 5feb353 commit 2016f83
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions monai/apps/pathology/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __call__(
self.process_output,
self.buffer_steps,
self.buffer_dim,
False,
*args,
**kwargs,
)
Expand Down
5 changes: 5 additions & 0 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ class SlidingWindowInferer(Inferer):
(i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
buffer_dim: the spatial dimension along which the buffers are created.
0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
with_coord: whether to pass the window coordinates to ``network``. Defaults to False.
If True, the ``network``'s 2nd input argument should accept the window coordinates.
Note:
``sw_batch_size`` denotes the max number of windows per network inference iteration,
Expand All @@ -449,6 +451,7 @@ def __init__(
cpu_thresh: int | None = None,
buffer_steps: int | None = None,
buffer_dim: int = -1,
with_coord: bool = False,
) -> None:
super().__init__()
self.roi_size = roi_size
Expand All @@ -464,6 +467,7 @@ def __init__(
self.cpu_thresh = cpu_thresh
self.buffer_steps = buffer_steps
self.buffer_dim = buffer_dim
self.with_coord = with_coord

# compute_importance_map takes long time when computing on cpu. We thus
# compute it once if it's static and then save it for future usage
Expand Down Expand Up @@ -525,6 +529,7 @@ def __call__(
None,
buffer_steps,
buffer_dim,
self.with_coord,
*args,
**kwargs,
)
Expand Down
8 changes: 7 additions & 1 deletion monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def sliding_window_inference(
process_fn: Callable | None = None,
buffer_steps: int | None = None,
buffer_dim: int = -1,
with_coord: bool = False,
*args: Any,
**kwargs: Any,
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
Expand Down Expand Up @@ -125,6 +126,8 @@ def sliding_window_inference(
(i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
buffer_dim: the spatial dimension along which the buffers are created.
0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
with_coord: whether to pass the window coordinates to ``predictor``. Default is False.
If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``.
args: optional args to be passed to ``predictor``.
kwargs: optional keyword args to be passed to ``predictor``.
Expand Down Expand Up @@ -220,7 +223,10 @@ def sliding_window_inference(
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch
if with_coord:
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch
else:
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch

# convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
Expand Down
1 change: 1 addition & 0 deletions tests/test_sliding_window_hovernet_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def compute(data, test1, test2):
None,
None,
0,
False,
t1,
test2=t2,
)
Expand Down
1 change: 1 addition & 0 deletions tests/test_sliding_window_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def compute(data, test1, test2):
None,
None,
0,
False,
t1,
test2=t2,
)
Expand Down

0 comments on commit 2016f83

Please sign in to comment.