Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transpose and patch coords bug #8047

Merged
merged 10 commits into from
Aug 28, 2024
9 changes: 5 additions & 4 deletions monai/apps/vista3d/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def sample_prompt_pairs(
max_backprompt: int, max number of prompt from background.
max_point: maximum number of points for each object.
include_background: if include 0 into training prompt. If included, background 0 is treated
the same as foreground. Always be False for multi-partial-dataset training. If needed,
can be true for finetuning specific dataset, .
the same as foreground and points will be sampled. Can be true only if user want to segment
background 0 with point clicks, otherwise always be false.
drop_label_prob: probability to drop label prompt.
drop_point_prob: probability to drop point prompt.
point_sampler: sampler to augment masks with supervoxel.
Expand All @@ -76,12 +76,13 @@ def sample_prompt_pairs(
Returns:
label_prompt: [B, 1]. The classes used for training automatic segmentation.
point: [B, N, 3]. The corresponding points for each class.
Note that background label prompt requires matching point as well ([0,0,0] is used).
Note that background label prompt requires matching point as well ([0,0,0] is used).
point_label: [B, N]. The corresponding point labels for each point (negative or positive).
-1 is used for padding the background label prompt and will be ignored.
-1 is used for padding the background label prompt and will be ignored.
prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss.
label_prompt can be None, and prompt_class is used to identify point classes.
"""

# class label number
if not labels.shape[0] == 1:
raise ValueError("only support batch size 1")
Expand Down
7 changes: 5 additions & 2 deletions monai/networks/nets/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,11 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
def forward(
self,
input_images: torch.Tensor,
patch_coords: Sequence[slice] | None = None,
point_coords: torch.Tensor | None = None,
point_labels: torch.Tensor | None = None,
class_vector: torch.Tensor | None = None,
prompt_class: torch.Tensor | None = None,
patch_coords: Sequence[slice] | None = None,
labels: torch.Tensor | None = None,
label_set: Sequence[int] | None = None,
prev_mask: torch.Tensor | None = None,
Expand Down Expand Up @@ -421,7 +421,10 @@ def forward(
point_coords, point_labels = None, None

if point_coords is None and class_vector is None:
return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
if transpose:
logits = logits.transpose(1, 0)
return logits

if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
out, out_auto = self.image_embeddings, None
Expand Down
Loading