Support instance masks (N,H,W) #2856

Merged: 5 commits, Apr 2, 2024

ashnair1
Contributor

@ashnair1 ashnair1 commented Mar 25, 2024

Changes

Fixes #2855; related to #941

Type of change

  • 🧪 Test Cases
  • 🐞 Bug fix (non-breaking change which fixes an issue)

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • Did you update CHANGELOG in case of a major change?

@ashnair1 ashnair1 marked this pull request as draft March 25, 2024 09:15
@ashnair1 ashnair1 marked this pull request as ready for review March 25, 2024 11:23
@johnnv1 johnnv1 requested a review from shijianjian March 25, 2024 23:57
Member

@shijianjian shijianjian left a comment

Meanwhile, what we did for the boxes is merge them all into one padded tensor:

def _merge_box_list(boxes: list[torch.Tensor], method: str = "pad") -> tuple[torch.Tensor, list[int]]:
    r"""Merge a list of boxes into one tensor."""
    if not all(box.shape[-2:] == torch.Size([4, 2]) and box.dim() == 3 for box in boxes):
        raise TypeError(f"Input boxes must be a list of (N, 4, 2) shaped. Got: {[box.shape for box in boxes]}.")
    if method == "pad":
        max_N = max(box.shape[0] for box in boxes)
        stats = [max_N - box.shape[0] for box in boxes]
        output = torch.nn.utils.rnn.pad_sequence(boxes, batch_first=True)
    else:
        raise NotImplementedError(f"`{method}` is not implemented.")
    return output, stats

Not sure if we should do the same for the masks. Can anyone benchmark which approach runs faster: iteration or batching?
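For readers unfamiliar with the padding step, here is a minimal sketch of what the "pad" branch above does, using only `torch.nn.utils.rnn.pad_sequence`. The box counts (3 and 1) and the variable names `pad_counts` and `merged` are made up for illustration and are not taken from the kornia code.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Illustrative input: two images with 3 and 1 boxes,
# each box stored as 4 corner points (x, y).
boxes = [torch.rand(3, 4, 2), torch.rand(1, 4, 2)]

# Number of zero-padded rows each sample needs to reach the largest N.
max_n = max(b.shape[0] for b in boxes)
pad_counts = [max_n - b.shape[0] for b in boxes]

# pad_sequence zero-pads along dim 0 and stacks into (B, max_N, 4, 2).
merged = pad_sequence(boxes, batch_first=True)

print(merged.shape)
print(pad_counts)
```

The padded rows are all zeros, so the per-sample padding counts are what make it possible to strip them again after augmentation.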

kornia/augmentation/container/ops.py (review thread resolved)
@ashnair1
Contributor Author

ashnair1 commented Mar 29, 2024

Merging a list of tensors into a single tensor will not work for instance segmentation. Each image has a different number of objects, and therefore a different number of masks, so the masks cannot be stacked into one batch tensor.

This is what it would look like:
Single sample: Image (C, H, W) and Mask (N, H, W)
Batched: Image (B, C, H, W) and Mask ([(N1, H, W), (N2, H, W) ... (Nb, H, W)])

N (no. of detections) varies from image to image, so they cannot be batched.

Refer to #2417 for a proper example of this use case.
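To make the shape problem concrete, here is a small sketch in plain PyTorch. The mask counts, the image size, and the `hflip` helper are hypothetical stand-ins for illustration; `hflip` is not a kornia function and this is not the code changed in this PR.

```python
import torch

H, W = 4, 4
# Hypothetical batch: two images with 3 and 1 instance masks respectively.
masks = [torch.zeros(3, H, W), torch.zeros(1, H, W)]

# torch.stack requires identical shapes, so a ragged list cannot be batched directly.
try:
    torch.stack(masks)
    stacked_ok = True
except RuntimeError:
    stacked_ok = False


def hflip(m: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real augmentation: flip along the width axis.
    return m.flip(-1)


# Iterating over the list keeps each sample's own N.
out = [hflip(m) for m in masks]
out_shapes = [tuple(m.shape) for m in out]
print(stacked_ok, out_shapes)
```

Iterating sidesteps the ragged-shape issue at the cost of one call per sample.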

@ashnair1 ashnair1 requested a review from shijianjian March 29, 2024 19:39
@shijianjian
Member

Single sample: Image (C, H, W) and Mask (N, H, W)
Batched: Image (B, C, H, W) and Mask ([(N1, H, W), (N2, H, W) ... (Nb, H, W)])

I meant that

N_max = max(N1, ..., Nb)
padding = [N_max - N1, ..., N_max - Nb]
padded_masks = (B, N_max, H, W)

... # After augmentation

Unpad (B, N_max, H, W) => Mask ([(N1, H, W), (N2, H, W) ... (Nb, H, W)])

I think the current change is fine for this PR. We should benchmark the iterative and batching strategies to see how they compare. Maybe we could add a Mask class, similar to Boxes, to handle these.
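The pad / augment / unpad idea outlined above could be sketched roughly as follows. `pad_masks` and `unpad_masks` are hypothetical helper names, not existing kornia functions, and the horizontal flip is only a stand-in for a batched augmentation. This sketch stores the original per-sample counts instead of the padding amounts, which carries the same information.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_masks(masks: list[torch.Tensor]) -> tuple[torch.Tensor, list[int]]:
    """Zero-pad a list of (N_i, H, W) masks into one (B, N_max, H, W) tensor."""
    counts = [m.shape[0] for m in masks]
    return pad_sequence(masks, batch_first=True), counts


def unpad_masks(padded: torch.Tensor, counts: list[int]) -> list[torch.Tensor]:
    """Drop the padding rows and return a list of (N_i, H, W) masks."""
    return [padded[i, :n] for i, n in enumerate(counts)]


# Illustrative batch: 3 masks for the first image, 1 for the second.
masks = [torch.rand(3, 8, 8), torch.rand(1, 8, 8)]

padded, counts = pad_masks(masks)      # (2, 3, 8, 8)
augmented = padded.flip(-1)            # stand-in for a batched augmentation
restored = unpad_masks(augmented, counts)

print(padded.shape, [tuple(m.shape) for m in restored])
```

Whether this beats plain iteration likely depends on how uneven the counts are, since the padded rows are augmented and then thrown away. Timing both variants, for example with `torch.utils.benchmark.Timer`, would settle it for a given workload.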

@shijianjian shijianjian merged commit 2c761ee into kornia:main Apr 2, 2024
27 checks passed
@ashnair1 ashnair1 deleted the msk-dim-fix branch April 2, 2024 13:49
cjpurackal pushed a commit to cjpurackal/kornia that referenced this pull request May 18, 2024
* Fix for shape error in transform_masks

* Iterate over batch_prob in transform_list

* Run ruff

* Revert prev changes

* Update test case

Successfully merging this pull request may close these issues.

AugmentationSequential does not support instance masks shape (N, H, W)
4 participants