Skip to content

Commit 8015ad5

Browse files
林旻佑林旻佑
authored andcommitted
BUGFIX: support NDHWC input in sliding_window_inference and DiceMetric
1 parent d388d1c commit 8015ad5

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

monai/inferers/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def sliding_window_inference(
6060
*args: Any,
6161
**kwargs: Any,
6262
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
63+
6364
"""
6465
Sliding window inference on `inputs` with `predictor`.
6566
@@ -134,6 +135,14 @@ def sliding_window_inference(
134135
- input must be channel-first and have a batch dim, supports N-D sliding window.
135136
136137
"""
138+
139+
# auto transform (N,D,H,W,C) → (N,C,D,H,W)
140+
if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
141+
inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
142+
143+
144+
145+
137146
buffered = buffer_steps is not None and buffer_steps > 0
138147
num_spatial_dims = len(inputs.shape) - 2
139148
if buffered:

monai/metrics/meandice.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
134134
Raises:
135135
ValueError: when `y_pred` has fewer than three dimensions.
136136
"""
137+
138+
if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
139+
y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
140+
if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
141+
y = y.permute(0, 4, 1, 2, 3).contiguous()
142+
137143
dims = y_pred.ndimension()
138144
if dims < 3:
139145
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")

0 commit comments

Comments
 (0)