diff --git a/docs/requirements.txt b/docs/requirements.txt
index ff94f7b6de..7307d8e5f9 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -42,3 +42,4 @@ zarr
 huggingface_hub
 pyamg>=5.0.0
 packaging
+polygraphy
diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py
index b7aeb89a2e..17b2d34911 100644
--- a/monai/apps/vista3d/sampler.py
+++ b/monai/apps/vista3d/sampler.py
@@ -20,8 +20,6 @@
 import torch
 from torch import Tensor
 
-__all__ = ["sample_prompt_pairs"]
-
 ENABLE_SPECIAL = True
 SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
 MERGE_LIST = {
@@ -30,6 +28,8 @@
     132: [57],  # overlap with trachea merge into airway
 }
 
+__all__ = ["sample_prompt_pairs"]
+
 
 def _get_point_label(id: int) -> tuple[int, int]:
     if id in SPECIAL_INDEX and ENABLE_SPECIAL:
@@ -66,22 +66,29 @@ def sample_prompt_pairs(
         max_backprompt: int, max number of prompt from background.
         max_point: maximum number of points for each object.
         include_background: if include 0 into training prompt. If included, background 0 is treated
-            the same as foreground. Always be False for multi-partial-dataset training. If needed,
-            can be true for finetuning specific dataset, .
+            the same as foreground and points will be sampled. Can be True only if the user wants to
+            segment background 0 with point clicks; otherwise it should always be False.
         drop_label_prob: probability to drop label prompt.
         drop_point_prob: probability to drop point prompt.
         point_sampler: sampler to augment masks with supervoxel.
         point_sampler_kwargs: arguments for point_sampler.
 
     Returns:
-        label_prompt: [B, 1]. The classes used for training automatic segmentation.
-        point: [B, N, 3]. The corresponding points for each class.
-        Note that background label prompt requires matching point as well ([0,0,0] is used).
-        point_label: [B, N]. The corresponding point labels for each point (negative or positive).
-        -1 is used for padding the background label prompt and will be ignored.
-        prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss.
-        label_prompt can be None, and prompt_class is used to identify point classes.
+        tuple:
+            - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
+              training automatic segmentation.
+            - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
+              for each class. Note that background label prompts require matching points as well
+              (e.g., [0, 0, 0] is used).
+            - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
+              labels for each point (negative or positive). -1 is used for padding the background
+              label prompt and will be ignored.
+            - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
+              for label indexing during training. If label_prompt is None, prompt_class is used to
+              identify point classes.
+
+    """
+    # class label number
     if not labels.shape[0] == 1:
         raise ValueError("only support batch size 1")
diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py
index 9148e36542..979a090df0 100644
--- a/monai/networks/nets/vista3d.py
+++ b/monai/networks/nets/vista3d.py
@@ -336,11 +336,11 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
     def forward(
         self,
         input_images: torch.Tensor,
+        patch_coords: Sequence[slice] | None = None,
         point_coords: torch.Tensor | None = None,
         point_labels: torch.Tensor | None = None,
         class_vector: torch.Tensor | None = None,
         prompt_class: torch.Tensor | None = None,
-        patch_coords: Sequence[slice] | None = None,
         labels: torch.Tensor | None = None,
         label_set: Sequence[int] | None = None,
         prev_mask: torch.Tensor | None = None,
@@ -421,7 +421,10 @@ def forward(
             point_coords, point_labels = None, None
 
         if point_coords is None and class_vector is None:
-            return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+            logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+            if transpose:
+                logits = logits.transpose(1, 0)
+            return logits
 
         if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
             out, out_auto = self.image_embeddings, None
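Reviewer note on the `forward` signature change: because `patch_coords` now precedes `point_coords`, any downstream call that passed the prompt tensors positionally will silently bind them to the wrong parameters. Below is a minimal sketch of the keyword-argument calling pattern that is robust to the reordering. It is illustrative only: `vista3d132` is assumed here as the model factory and its construction defaults are an assumption, while the tensor shapes follow the docstring above ([B, N, 3] points, [B, N] point labels).

```python
import torch

from monai.networks.nets import vista3d132  # assumed factory; adjust to your build

# Hypothetical construction for illustration; real checkpoints/configs may differ.
model = vista3d132(in_channels=1).eval()

images = torch.rand(1, 1, 96, 96, 96)          # [B, C, H, W, D] input volume
points = torch.tensor([[[48.0, 48.0, 48.0]]])  # [B, N, 3] click coordinates
point_labels = torch.tensor([[1]])             # [B, N]; 1 = positive click

# A legacy positional call such as model(images, points, point_labels) would now
# bind `points` to the new second parameter `patch_coords` instead of
# `point_coords`. Keyword arguments are immune to the reordering:
with torch.no_grad():
    logits = model(
        input_images=images,
        point_coords=points,
        point_labels=point_labels,
    )
```

The same caution applies to the new early-return path: when the `transpose` flag (introduced elsewhere in this PR) is set, the NINF placeholder logits come back with the batch and channel dimensions swapped, matching the transposed layout of the regular outputs.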