Decouple DinoV2 for semantic segmentation (#4136)
* dinov2 decoupled. Perf tests

* added dino

* remove dinov2 backbone

* fix linter

* remove unit test

* fix integration tests

* revert perf test back
kprokofi authored Dec 4, 2024
1 parent 5707bc5 commit 5d6f8d3
Showing 10 changed files with 203 additions and 234 deletions.
167 changes: 158 additions & 9 deletions src/otx/algo/classification/backbones/vision_transformer.py
@@ -5,6 +5,7 @@
"""Copy from mmpretrain/models/backbones/vision_transformer.py."""
from __future__ import annotations

import math
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal

@@ -46,6 +47,7 @@
"vit-huge",
"dinov2-s",
"dinov2-small",
"dinov2-small-seg",
"dinov2-b",
"dinov2-base",
"dinov2-l",
@@ -87,6 +89,7 @@ class VisionTransformer(BaseModule):
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
interpolate_offset: work-around offset to apply when interpolating positional embeddings
lora: Enable LoRA training.
"""

@@ -147,6 +150,17 @@ class VisionTransformer(BaseModule):
"num_heads": 6,
"reg_tokens": 4,
"no_embed_class": True,
},
),
**dict.fromkeys(
["dinov2-small-seg"], # segmentation
{
"patch_size": 14,
"embed_dim": 384,
"depth": 12,
"num_heads": 6,
"reg_tokens": 0,
"no_embed_class": False,
"init_values": 1e-5,
},
),
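For orientation, a minimal sketch of how the new preset could be used once this change lands. The import path follows this file's location in the tree above; the 518-pixel input size is an assumption based on DINOv2's usual pretraining resolution, not something this diff pins down:

```python
import torch

from otx.algo.classification.backbones.vision_transformer import VisionTransformer

# Hypothetical usage: "dinov2-small-seg" resolves to the preset added above
# (patch_size=14, embed_dim=384, depth=12, num_heads=6, no register tokens).
backbone = VisionTransformer(arch="dinov2-small-seg", img_size=518)  # img_size assumed
out = backbone(torch.randn(1, 3, 518, 518))
```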
@@ -193,9 +207,9 @@ class VisionTransformer(BaseModule):

def __init__( # noqa: PLR0913
self,
arch: VIT_ARCH_TYPE = "vit-base",
arch: VIT_ARCH_TYPE | str = "vit-base",
img_size: int | tuple[int, int] = 224,
patch_size: int | tuple[int, int] | None = None,
patch_size: int | None = None,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int | None = None,
@@ -221,6 +235,7 @@ def __init__( # noqa: PLR0913
mlp_layer: nn.Module | None = None,
act_layer: LayerType | None = None,
norm_layer: LayerType | None = None,
interpolate_offset: float = 0.1,
lora: bool = False,
) -> None:
super().__init__()
@@ -231,7 +246,7 @@ def __init__( # noqa: PLR0913
arch_settings: dict[str, Any] = self.arch_zoo[arch]

self.img_size: int | tuple[int, int] = img_size
self.patch_size: int | tuple[int, int] = patch_size or arch_settings.get("patch_size", 16)
self.patch_size: int = patch_size or arch_settings.get("patch_size", 16)
self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
depth = depth or arch_settings.get("depth", 12)
num_heads = num_heads or arch_settings.get("num_heads", 12)
@@ -251,6 +266,7 @@ def __init__( # noqa: PLR0913
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.interpolate_offset = interpolate_offset

embed_args = {}
if dynamic_img_size:
@@ -353,15 +369,17 @@ def resize_positional_embeddings(pos_embed: torch.Tensor, new_shape: tuple[int,
# convert dinov2 pretrained weights
state_dict = torch.load(checkpoint_path)
state_dict.pop("mask_token", None)
state_dict["reg_token"] = state_dict.pop("register_tokens")
if "reg_token" in state_dict:
state_dict["reg_token"] = state_dict.pop("register_tokens")
state_dict["cls_token"] = state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0]

img_size = (self.img_size, self.img_size) if isinstance(self.img_size, int) else self.img_size
patch_size = (self.patch_size, self.patch_size) if isinstance(self.patch_size, int) else self.patch_size
state_dict["pos_embed"] = resize_positional_embeddings(
state_dict.pop("pos_embed")[:, 1:],
(img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
)
patch_size = (self.patch_size, self.patch_size)
if state_dict["pos_embed"].shape != self.pos_embed.shape:
state_dict["pos_embed"] = resize_positional_embeddings(
state_dict.pop("pos_embed")[:, 1:],
(img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
)
self.load_state_dict(state_dict, strict=False)
else:
msg = f"Unsupported `checkpoint_extension` {checkpoint_ext}, please choose from 'npz' or 'pth'."
@@ -401,6 +419,137 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:

return self.pos_drop(x)

def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
"""Interpolates the positional encoding to match the input dimensions.
Args:
x (torch.Tensor): Input tensor.
w (int): Width of the input image.
h (int): Height of the input image.
Returns:
torch.Tensor: Tensor with interpolated positional encoding.
"""
previous_dtype = x.dtype
npatch = x.shape[1] - 1  # patch tokens in the input, excluding the cls token
n = self.pos_embed.shape[1] - 1  # pretrained patch positions, excluding the cls slot
if npatch == n and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
m = int(math.sqrt(n)) # Recover the number of patches in each dimension
if m * m != n:
msg = f"Expected m * m to equal n, but got m={m}, n={n}"
raise ValueError(msg)
kwargs = {}
if self.interpolate_offset:
# fix float error by introducing small offset
sx = float(w0 + self.interpolate_offset) / m
sy = float(h0 + self.interpolate_offset) / m
kwargs["scale_factor"] = (sx, sy)
else:
# Simply specify an output size instead of a scale factor
kwargs["size"] = (w0, h0)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),
mode="bicubic",
**kwargs,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
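A worked numeric sketch of the scale-factor workaround above, independent of the class. All values are assumed: a DINOv2-small checkpoint pretrained at 518×518 with patch_size=14 stores a 37×37 grid of patch position embeddings:

```python
import math

import torch
from torch import nn

n, dim, patch = 37 * 37, 384, 14          # 1369 pretrained patch positions
m = int(math.sqrt(n))                     # 37 patches per side
patch_pos_embed = torch.randn(1, n, dim)  # stand-in for pos_embed[:, 1:]

w = h = 336                               # new input resolution
w0, h0 = w // patch, h // patch           # 24 x 24 target grid
sx = (w0 + 0.1) / m                       # 0.1 is the interpolate_offset default
sy = (h0 + 0.1) / m

resized = nn.functional.interpolate(
    patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),
    scale_factor=(sx, sy),
    mode="bicubic",
)
print(resized.shape)  # torch.Size([1, 384, 24, 24])
```

The offset nudges the scale factor so that floating-point rounding inside `interpolate` cannot land one pixel short of the intended 24×24 output; passing `size=(w0, h0)` directly (the `interpolate_offset == 0` branch) states the target grid explicitly and needs no such workaround.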

def prepare_tokens_with_masks(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor:
"""Prepare tokens with optional masks.
Args:
x (torch.Tensor): Input tensor.
masks (torch.Tensor | None): Optional masks tensor.
Returns:
torch.Tensor: Tensor with prepared tokens.
"""
_, _, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)

if self.reg_token is not None:
x = torch.cat(
(
x[:, :1],
self.reg_token.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)

return x
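As a sanity check on the token layout this method produces, a short sketch; shapes assume the `dinov2-small-seg` preset and the hypothetical constructor call from the earlier sketch:

```python
import torch

from otx.algo.classification.backbones.vision_transformer import VisionTransformer

backbone = VisionTransformer(arch="dinov2-small-seg", img_size=518)  # assumed, as above
x = torch.randn(2, 3, 518, 518)
tokens = backbone.prepare_tokens_with_masks(x)
# 1 cls token + (518 // 14) ** 2 = 1369 patch tokens, embed_dim 384
print(tokens.shape)  # expected: torch.Size([2, 1370, 384])
```

With an arch that keeps register tokens (e.g. `dinov2-small`, reg_tokens=4), the four register tokens are spliced in between the cls token and the patch tokens, giving 1374 tokens instead.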

def _get_intermediate_layers_not_chunked(self, x: torch.Tensor, n: int | list[int] = 1) -> list[torch.Tensor]:
"""Get intermediate layers without chunking.
Args:
x (torch.Tensor): Input tensor.
n (int | list[int]): Number of last blocks to take, or an explicit list of block indices.
Returns:
list[torch.Tensor]: List of intermediate layer outputs.
"""
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the last n blocks; if it's a list, take the specified blocks
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
if len(output) != len(blocks_to_take):
msg = f"only {len(output)} / {len(blocks_to_take)} blocks found"
raise RuntimeError(msg)
return output

def get_intermediate_layers(
self,
x: torch.Tensor,
n: int | list[int] = 1, # Number of last blocks to take, or a list of block indices
reshape: bool = False,
return_class_token: bool = False,
norm: bool = True,
) -> tuple:
"""Get intermediate layers of the VisionTransformer.
Args:
x (torch.Tensor): Input tensor.
n (int | list[int]): Number of last blocks to take, or an explicit list of block indices.
reshape (bool): Whether to reshape the output feature maps.
return_class_token (bool): Whether to return the class token.
norm (bool): Whether to apply normalization to the outputs.
Returns:
tuple: A tuple containing the intermediate layer outputs.
"""
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_reg_tokens :] for out in outputs]
if reshape:
b, _, w, h = x.shape
outputs = [
out.reshape(b, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
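Since this decoupling targets semantic segmentation, the natural consumer of this API is a decode head that wants spatial feature maps. A hedged usage sketch (the block indices and input size are illustrative, not taken from this diff):

```python
import torch

from otx.algo.classification.backbones.vision_transformer import VisionTransformer

backbone = VisionTransformer(arch="dinov2-small-seg", img_size=518)  # assumed, as above
img = torch.randn(1, 3, 518, 518)
# Four illustrative blocks, reshaped from token sequences to (B, C, H/14, W/14) maps
feats = backbone.get_intermediate_layers(img, n=[8, 9, 10, 11], reshape=True)
for f in feats:
    print(f.shape)  # each: torch.Size([1, 384, 37, 37])
```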

def forward(
self,
x: torch.Tensor,
3 changes: 1 addition & 2 deletions src/otx/algo/segmentation/backbones/__init__.py
@@ -3,8 +3,7 @@
#
"""Backbone modules for OTX segmentation model."""

from .dinov2 import DinoVisionTransformer
from .litehrnet import LiteHRNetBackbone
from .mscan import MSCAN

__all__ = ["LiteHRNetBackbone", "DinoVisionTransformer", "MSCAN"]
__all__ = ["LiteHRNetBackbone", "MSCAN"]
98 changes: 0 additions & 98 deletions src/otx/algo/segmentation/backbones/dinov2.py

This file was deleted.

