Fix transpose and patch coords bug #8047

Merged 10 commits on Aug 28, 2024

1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -42,3 +42,4 @@ zarr
huggingface_hub
pyamg>=5.0.0
packaging
+polygraphy

29 changes: 18 additions & 11 deletions monai/apps/vista3d/sampler.py
@@ -20,8 +20,6 @@
import torch
from torch import Tensor

__all__ = ["sample_prompt_pairs"]

ENABLE_SPECIAL = True
SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
MERGE_LIST = {
@@ -30,6 +28,8 @@
132: [57], # overlap with trachea merge into airway
}

+__all__ = ["sample_prompt_pairs"]
+

def _get_point_label(id: int) -> tuple[int, int]:
if id in SPECIAL_INDEX and ENABLE_SPECIAL:
@@ -66,22 +66,29 @@ def sample_prompt_pairs(
max_backprompt: int, max number of prompt from background.
max_point: maximum number of points for each object.
include_background: if include 0 into training prompt. If included, background 0 is treated
-    the same as foreground. Always be False for multi-partial-dataset training. If needed,
-    can be true for finetuning specific dataset, .
+    the same as foreground and points will be sampled. Can be true only if user want to segment
+    background 0 with point clicks, otherwise always be false.
drop_label_prob: probability to drop label prompt.
drop_point_prob: probability to drop point prompt.
point_sampler: sampler to augment masks with supervoxel.
point_sampler_kwargs: arguments for point_sampler.

Returns:
-    label_prompt: [B, 1]. The classes used for training automatic segmentation.
-    point: [B, N, 3]. The corresponding points for each class.
-    Note that background label prompt requires matching point as well ([0,0,0] is used).
-    point_label: [B, N]. The corresponding point labels for each point (negative or positive).
-    -1 is used for padding the background label prompt and will be ignored.
-    prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss.
-    label_prompt can be None, and prompt_class is used to identify point classes.
+    tuple:
+        - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
+          training automatic segmentation.
+        - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
+          for each class. Note that background label prompts require matching points as well
+          (e.g., [0, 0, 0] is used).
+        - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
+          labels for each point (negative or positive). -1 is used for padding the background
+          label prompt and will be ignored.
+        - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
+          for label indexing during training. If label_prompt is None, prompt_class is used to
+          identify point classes.

"""

# class label number
if not labels.shape[0] == 1:
raise ValueError("only support batch size 1")
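
For context, a minimal usage sketch of the function documented above (a hedged example, not part of the PR: the synthetic labels volume, the class set {0, 1, 2}, and the chosen keyword values are illustrative; argument names follow the docstring):

    import torch

    from monai.apps.vista3d.sampler import sample_prompt_pairs

    # Synthetic [1, 1, H, W, D] ground-truth label volume with classes {0, 1, 2}.
    labels = torch.randint(0, 3, (1, 1, 64, 64, 64))

    # Returns the four values described in the Returns section above; each
    # may be None, e.g. when the corresponding prompt type is dropped.
    label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
        labels,
        label_set=[0, 1, 2],
        max_point=20,
        drop_label_prob=0.2,
        drop_point_prob=0.2,
    )
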
7 changes: 5 additions & 2 deletions monai/networks/nets/vista3d.py
@@ -336,11 +336,11 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
def forward(
self,
input_images: torch.Tensor,
-patch_coords: Sequence[slice] | None = None,
point_coords: torch.Tensor | None = None,
point_labels: torch.Tensor | None = None,
class_vector: torch.Tensor | None = None,
prompt_class: torch.Tensor | None = None,
+patch_coords: Sequence[slice] | None = None,
labels: torch.Tensor | None = None,
label_set: Sequence[int] | None = None,
prev_mask: torch.Tensor | None = None,
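
Because this commit moves patch_coords later in the signature, positional callers that relied on the old parameter order would silently bind the wrong argument. A hedged sketch of a call that stays robust to the reorder by passing prompts as keywords (model, image, and the prompt tensors are assumed, hypothetical names; shapes follow the sampler docstring above):

    # Keyword arguments are unaffected by the patch_coords reorder above.
    logits = model(
        input_images=image,          # [1, 1, H, W, D] input volume
        point_coords=point,          # [B, N, 3] point prompts, or None
        point_labels=point_label,    # [B, N] point labels, or None
        class_vector=label_prompt,   # [B, 1] class prompts, or None
        patch_coords=None,           # normally supplied during sliding-window inference
    )
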
@@ -421,7 +421,10 @@ def forward(
point_coords, point_labels = None, None

if point_coords is None and class_vector is None:
-return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+if transpose:
+    logits = logits.transpose(1, 0)
+return logits

if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
out, out_auto = self.image_embeddings, None
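
The change above fixes the early return taken when no point or class prompts are given: it previously returned the NINF-filled logits without applying the transpose flag handled elsewhere in forward. A standalone sketch of the fixed behavior (NINF_VALUE, bs, image_size, and transpose are stand-ins for the model attribute and local variables in forward):

    import torch

    NINF_VALUE = -9999  # stand-in for the model's "no prediction" fill value
    bs, image_size, transpose = 2, (64, 64, 64), True

    logits = NINF_VALUE + torch.zeros([bs, 1, *image_size])
    if transpose:
        # Keep the early return consistent with the other paths: [bs, 1, ...] -> [1, bs, ...]
        logits = logits.transpose(1, 0)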