diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 709f81f624..8f622ef6cd 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -100,7 +100,7 @@ def point_based_window_inferer( point_labels=point_labels, class_vector=class_vector, prompt_class=prompt_class, - patch_coords=unravel_slice, + patch_coords=[unravel_slice], prev_mask=prev_mask, **kwargs, ) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 979a090df0..4215a9a594 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -336,7 +336,7 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): def forward( self, input_images: torch.Tensor, - patch_coords: Sequence[slice] | None = None, + patch_coords: list[Sequence[slice]] | None = None, point_coords: torch.Tensor | None = None, point_labels: torch.Tensor | None = None, class_vector: torch.Tensor | None = None, @@ -364,8 +364,12 @@ def forward( the points are for zero-shot or supported class. When class_vector and point_coords are both provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] will be considered novel class. - patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. - This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase. + patch_coords: a list of sequences of Python slice objects representing the patch coordinates during sliding window + inference. This value is passed from sliding_window_inferer. + This is an indicator for training phase or validation phase. + Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will include + coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the + functions using patch_coords will by default use patch_coords[0]. 
labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot @@ -395,14 +399,14 @@ def forward( if val_point_sampler is None: # TODO: think about how to refactor this part. val_point_sampler = self.sample_points_patch_val - point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set) + point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set) if prompt_class[0].item() == 0: # type: ignore point_labels[0] = -1 # type: ignore labels, prev_mask = None, None elif point_coords is not None: # If not performing patch-based point only validation, use user provided click points for inference. # the point clicks is in original image space, convert it to current patch-coordinate space. - point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore + point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels) # type: ignore if point_coords is not None and point_labels is not None: # remove points that used for padding purposes (point_label = -1) @@ -455,7 +459,7 @@ def forward( logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) if prev_mask is not None and patch_coords is not None: logits = self.connected_components_combine( - prev_mask[patch_coords].transpose(1, 0).to(logits.device), + prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device), logits[mapping_index], point_coords, # type: ignore point_labels, # type: ignore