
Commit

added interpolation for owlvit & owlv2.
kshitij committed Oct 20, 2024
1 parent 816f442 commit 9f54fc9
Showing 2 changed files with 44 additions and 3 deletions.
31 changes: 29 additions & 2 deletions src/transformers/models/owlv2/modeling_owlv2.py
@@ -291,14 +291,41 @@ def __init__(self, config: Owlv2VisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
+        target_size = self.config.image_size
+
+        # When interpolating position encodings, inputs of any resolution are accepted and the
+        # position embeddings are resized below; otherwise the pretrained resolution is enforced.
+        if not interpolate_pos_encoding and (
+            pixel_values.shape[2] != target_size or pixel_values.shape[3] != target_size
+        ):
+            raise ValueError(
+                f"Input image size ({pixel_values.shape[2]}*{pixel_values.shape[3]}) doesn't match model"
+                f" ({target_size}*{target_size})."
+            )

        patch_embeds = self.patch_embedding(pixel_values)  # shape = [batch_size, num_channels, height, width]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        if interpolate_pos_encoding:
+            # Bilinearly resize the pretrained patch-position grid to the current (square) patch
+            # grid; the class-token position embedding at index 0 is left unchanged.
+            pos_embedding = self.position_embedding(self.position_ids)  # (1, num_positions, embed_dim)
+            class_pos_embedding = pos_embedding[:, :1]
+            patch_pos_embedding = pos_embedding[:, 1:]
+            pretrained_grid = int(patch_pos_embedding.shape[1] ** 0.5)
+            new_grid = int(patch_embeds.shape[1] ** 0.5)
+            patch_pos_embedding = patch_pos_embedding.reshape(1, pretrained_grid, pretrained_grid, -1).permute(0, 3, 1, 2)
+            patch_pos_embedding = nn.functional.interpolate(
+                patch_pos_embedding, size=(new_grid, new_grid), mode="bilinear", align_corners=False
+            )
+            patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
+            embeddings = embeddings + torch.cat([class_pos_embedding, patch_pos_embedding], dim=1)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings

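Below is a minimal usage sketch of what this change enables, assuming the new interpolate_pos_encoding flag is threaded through to the top-level model forward in the collapsed remainder of this file's diff; the checkpoint name and image path are illustrative placeholders, not part of this commit:

import torch
from PIL import Image
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Hypothetical usage: ask the vision tower to interpolate its position embeddings instead of
# rejecting an input whose resolution differs from the pretrained one. Note that the processor
# resizes images to the pretrained resolution by default, so it would also have to be configured
# to keep the native resolution for the flag to have an effect.
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

image = Image.open("example.jpg")  # placeholder path
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)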
16 changes: 15 additions & 1 deletion src/transformers/models/owlvit/modeling_owlvit.py
@@ -19,6 +19,7 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
+import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, nn

@@ -285,8 +286,21 @@ def __init__(self, config: OwlViTVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
+        target_size = self.config.image_size
+
+        # Mirrors the OWLv2 embedding above: the pretrained resolution is only enforced when
+        # position encodings are not being interpolated.
+        if not interpolate_pos_encoding and (
+            pixel_values.shape[2] != target_size or pixel_values.shape[3] != target_size
+        ):
+            raise ValueError(
+                f"Input image size ({pixel_values.shape[2]}*{pixel_values.shape[3]}) doesn't match model"
+                f" ({target_size}*{target_size})."
+            )

patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

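The rest of this hunk (collapsed above) presumably mirrors the OWLv2 change by interpolating the position embeddings inside OwlViTVisionEmbeddings.forward. For reference, here is a standalone sketch of the standard ViT-style interpolation; the function name and signature are illustrative and are not identifiers from modeling_owlvit.py:

import torch
import torch.nn.functional as F


def interpolate_vit_pos_embedding(pos_embedding: torch.Tensor, new_height: int, new_width: int) -> torch.Tensor:
    """Resize a (1, 1 + H*W, dim) position embedding to a new patch grid (illustrative helper)."""
    class_pos = pos_embedding[:, :1]            # class-token embedding stays as-is
    patch_pos = pos_embedding[:, 1:]            # (1, H*W, dim) spatial grid
    dim = patch_pos.shape[-1]
    old_size = int(patch_pos.shape[1] ** 0.5)   # assumes a square pretrained grid

    patch_pos = patch_pos.reshape(1, old_size, old_size, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_height, new_width), mode="bilinear", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_height * new_width, dim)
    return torch.cat([class_pos, patch_pos], dim=1)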
