Skip to content

Commit

Permalink
Merge pull request #19 from NVlabs/dev
Browse files Browse the repository at this point in the history
update model & checkpoints
  • Loading branch information
ahatamiz authored Aug 11, 2022
2 parents f679914 + c77aae0 commit 3900ff4
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions models/gc_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,16 @@ def __init__(self,
self.num_windows = int((input_resolution // window_size) * (input_resolution // window_size))

# Forward pass: normalise the input, run attention over fixed-size windows
# (with q_global passed to the attention module), add the result back onto the
# input, then add a second branch that runs norm2 -> mlp. Both branch outputs
# are multiplied by gamma1 / gamma2 and passed through drop_path before the add.
# NOTE(review): indentation was lost in this view; everything down to the final
# `return x` is presumably the body of this method.
def forward(self, x, q_global):
# x is unpacked as a 4-D tensor; the names suggest (batch, height, width,
# channels) - TODO confirm against the caller. B is not used afterwards.
B, H, W, C = x.shape
# Keep the un-normalised input for the first residual add.
shortcut = x
x = self.norm1(x)
# window_partition is defined outside this view; presumably it cuts x into
# window_size x window_size tiles - verify against its definition.
x_windows = window_partition(x, self.window_size)
# Flatten each window to window_size * window_size positions of C channels.
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, q_global)
# window_reverse is also defined elsewhere; presumably it reassembles the
# windows into an H x W layout so the add with `shortcut` lines up - confirm.
x = window_reverse(attn_windows, self.window_size, H, W)
x = shortcut + self.drop_path(self.gamma1 * x)
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
return x
# NOTE(review): the ten statements below repeat the ten above token-for-token.
# This looks like the removed/added sides of a whitespace-only diff hunk whose
# +/- markers were stripped (the page header reports 13 additions and 10
# deletions) - confirm before treating either copy as the live body.
B, H, W, C = x.shape
shortcut = x
x = self.norm1(x)
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, q_global)
x = window_reverse(attn_windows, self.window_size, H, W)
x = shortcut + self.drop_path(self.gamma1 * x)
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
return x


class GlobalQueryGen(nn.Module):
Expand All @@ -474,6 +474,9 @@ def __init__(self,
input_resolution: input image resolution.
window_size: window size.
num_heads: number of heads.
For instance, repeating log2(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at
a down-sampling ratio of 2. Please check Fig. 5 of the GC ViT paper for details.
"""

super().__init__()
Expand Down

0 comments on commit 3900ff4

Please sign in to comment.