Add masked average pooling for pooling with segmentation masks (DetCon) (#1739)

* add masked average pooling for DetCon

* add tests for masked average pooling

* skip tests for torch < 1.12

* fix formatting

* add return type hints

* add return type hints to tests

* Update lightly/models/utils.py

Co-authored-by: guarin <43336610+guarin@users.noreply.github.com>

* Update lightly/models/utils.py

Co-authored-by: guarin <43336610+guarin@users.noreply.github.com>

* change tensor shape doc convention

* call batched function instead of separate implementation

* change to new_zeros()

* remove packaging dependency

* Update tests/models/test_ModelUtils.py

Co-authored-by: guarin <43336610+guarin@users.noreply.github.com>

* squeeze output of unbatched function

* fix scatter ops availability check

* add type hints to tests

* make test more readable

* remove unnecessary import

* rename pooling function

---------

Co-authored-by: guarin <43336610+guarin@users.noreply.github.com>
liopeer and guarin authored Nov 21, 2024
1 parent 8781c33 commit 48ece5e
Showing 2 changed files with 173 additions and 0 deletions.
57 changes: 57 additions & 0 deletions lightly/models/utils.py
@@ -27,6 +27,63 @@
from timm.models.vision_transformer import VisionTransformer


def pool_masked(
    source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
) -> Tensor:
    """Reduce image feature maps (B, C, H, W) or (C, H, W) according to an integer
    index given by `mask` (B, H, W) or (H, W).

    Args:
        source: Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced.
        mask: Integer tensor of shape (B, H, W) or (H, W) containing the integer
            indices.
        reduce: The reduction operation to be applied, one of 'prod', 'mean', 'amax'
            or 'amin'. Defaults to 'mean'.
        num_cls: The number of classes in the possible masks. If None, the number of
            classes is inferred from the unique elements in `mask`. This is useful
            when not all classes are present in the mask.

    Returns:
        A tensor of shape (B, C, N) or (C, N) where N is the number of unique
        elements in `mask`, or `num_cls` if specified.
    """
    if source.dim() == 3:
        return _mask_reduce(source, mask, reduce, num_cls)
    elif source.dim() == 4:
        return _mask_reduce_batched(source, mask, reduce, num_cls)
    else:
        raise ValueError("source must have 3 or 4 dimensions")
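

# Usage sketch (illustrative shapes, not part of this diff): pool a batch of
# backbone feature maps into one vector per mask region, e.g. for DetCon-style
# region features. Assumes `torch` and `pool_masked` as defined in this module.
example_features = torch.randn(8, 256, 14, 14)  # (B, C, H, W) backbone features
example_masks = torch.randint(0, 16, (8, 14, 14))  # (B, H, W) region indices in [0, 16)
example_pooled = pool_masked(example_features, example_masks, num_cls=16)
assert example_pooled.shape == (8, 256, 16)  # one pooled feature vector per region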


def _mask_reduce(
    source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
) -> Tensor:
    output = _mask_reduce_batched(
        source.unsqueeze(0), mask.unsqueeze(0), reduce=reduce, num_cls=num_cls
    )
    return output.squeeze(0)


def _mask_reduce_batched(
    source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
) -> Tensor:
    b, c, h, w = source.shape
    if num_cls is None:
        cls = mask.unique(sorted=True)
    else:
        cls = torch.arange(num_cls, device=mask.device)
    num_cls = cls.size(0)
    # create output tensor with one slot per class
    output = source.new_zeros((b, c, num_cls))  # (B C N)
    # broadcast the mask over the channel dimension and flatten the spatial dims
    mask = mask.unsqueeze(1).expand(-1, c, -1, -1).view(b, c, -1)  # (B C HW)
    source = source.view(b, c, -1)  # (B C HW)
    output.scatter_reduce_(
        dim=2, index=mask, src=source, reduce=reduce, include_self=False
    )  # (B C N)
    # scatter_reduce_ produces NaNs if a class receives no elements
    output = torch.nan_to_num(output, nan=0.0)
    return output
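

# Reference sketch (illustrative, not part of this diff): a naive loop that computes
# the same result as `_mask_reduce_batched` for reduce="mean". Every output slot n
# holds the mean of the source pixels whose mask index equals n, or 0 if none exist.
def _mask_reduce_naive(source: Tensor, mask: Tensor, num_cls: int) -> Tensor:
    b, c, h, w = source.shape
    output = source.new_zeros((b, c, num_cls))
    for i in range(b):
        for n in range(num_cls):
            selected = source[i, :, mask[i] == n]  # (C, count) pixels of class n
            if selected.numel() > 0:
                output[i, :, n] = selected.mean(dim=1)
    return output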


@torch.no_grad()
def batch_shuffle(
    batch: torch.Tensor, distributed: bool = False
116 changes: 116 additions & 0 deletions tests/models/test_ModelUtils.py
@@ -12,16 +12,132 @@

from lightly.models import utils
from lightly.models.utils import (
    _mask_reduce,
    _mask_reduce_batched,
    _no_grad_trunc_normal,
    activate_requires_grad,
    batch_shuffle,
    batch_unshuffle,
    deactivate_requires_grad,
    nearest_neighbors,
    normalize_weight,
    pool_masked,
    update_momentum,
)

is_scatter_reduce_available = hasattr(Tensor, "scatter_reduce_")


@pytest.mark.skipif(
    not is_scatter_reduce_available,
    reason="scatter operations require torch >= 1.12.0",
)
class TestMaskReduce:
    @pytest.fixture()
    def mask1(self) -> Tensor:
        return torch.tensor([[0, 0], [1, 2]], dtype=torch.int64)

    @pytest.fixture()
    def mask2(self) -> Tensor:
        return torch.tensor([[1, 0], [0, 1]], dtype=torch.int64)

    @pytest.fixture()
    def feature_map1(self) -> Tensor:
        feature_map = torch.tensor(
            [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]],
            dtype=torch.float32,
        )  # (C, H, W) = (3, 2, 2)
        return feature_map

    @pytest.fixture()
    def feature_map2(self) -> Tensor:
        feature_map = torch.tensor(
            [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
            dtype=torch.float32,
        )  # (C, H, W) = (3, 2, 2)
        return feature_map

    @pytest.fixture()
    def expected_result1(self) -> Tensor:
        res = torch.tensor(
            [[0.5, 2.0, 3.0], [4.5, 6.0, 7.0], [8.5, 10.0, 11.0]], dtype=torch.float32
        )
        return res

    @pytest.fixture()
    def expected_result2(self) -> Tensor:
        res = torch.tensor(
            [[2.5, 2.5, 0.0], [6.5, 6.5, 0.0], [10.5, 10.5, 0.0]], dtype=torch.float32
        )
        return res

    def test__mask_reduce_batched(
        self,
        feature_map1: Tensor,
        feature_map2: Tensor,
        mask1: Tensor,
        mask2: Tensor,
        expected_result1: Tensor,
        expected_result2: Tensor,
    ) -> None:
        feature_map = torch.stack([feature_map1, feature_map2], dim=0)
        mask = torch.stack([mask1, mask2], dim=0)
        expected_result = torch.stack([expected_result1, expected_result2], dim=0)

        out = _mask_reduce_batched(feature_map, mask, num_cls=3)
        assert (out == expected_result).all()

    def test_masked_pooling_manual(
        self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
    ) -> None:
        out_manual = pool_masked(
            feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=2
        )
        assert out_manual.shape == (1, 3, 2)
        assert (out_manual == expected_result2[:, :2]).all()

    def test_masked_pooling_auto(
        self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
    ) -> None:
        out_auto = pool_masked(
            feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=None
        )
        assert out_auto.shape == (1, 3, 2)
        assert (out_auto == expected_result2[:, :2]).all()

    @pytest.mark.parametrize(
        "feature_map, mask, expected_result",
        [
            (
                torch.tensor(
                    [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]],
                    dtype=torch.float32,
                ),
                torch.tensor([[0, 0], [1, 2]], dtype=torch.int64),
                torch.tensor(
                    [[0.5, 2.0, 3.0], [4.5, 6.0, 7.0], [8.5, 10.0, 11.0]],
                    dtype=torch.float32,
                ),
            ),
            (
                torch.tensor(
                    [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
                    dtype=torch.float32,
                ),
                torch.tensor([[1, 0], [0, 1]], dtype=torch.int64),
                torch.tensor(
                    [[2.5, 2.5, 0.0], [6.5, 6.5, 0.0], [10.5, 10.5, 0.0]],
                    dtype=torch.float32,
                ),
            ),
        ],
    )
    def test__mask_reduce(
        self, feature_map: Tensor, mask: Tensor, expected_result: Tensor
    ) -> None:
        out = _mask_reduce(feature_map, mask, num_cls=3)
        assert (out == expected_result).all()
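
# Worked check for the fixtures above (illustrative): with mask1 = [[0, 0], [1, 2]]
# and channel 0 of feature_map1 = [[0, 1], [2, 3]], class 0 averages the two pixels
# with mask value 0, (0 + 1) / 2 = 0.5, class 1 takes the single pixel 2.0, and
# class 2 the single pixel 3.0, giving the first row [0.5, 2.0, 3.0] of
# expected_result1. Class 2 never occurs in mask2, so expected_result2 keeps 0.0 in
# its last column, matching the nan_to_num handling of empty classes.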


def has_grad(model: nn.Module):
    """Helper method to check if a model has `requires_grad` set to True"""
