diff --git a/models/gc_vit.py b/models/gc_vit.py index 376ea95..f66dd29 100644 --- a/models/gc_vit.py +++ b/models/gc_vit.py @@ -445,16 +445,16 @@ def __init__(self, self.num_windows = int((input_resolution // window_size) * (input_resolution // window_size)) def forward(self, x, q_global): - B, H, W, C = x.shape - shortcut = x - x = self.norm1(x) - x_windows = window_partition(x, self.window_size) - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) - attn_windows = self.attn(x_windows, q_global) - x = window_reverse(attn_windows, self.window_size, H, W) - x = shortcut + self.drop_path(self.gamma1 * x) - x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) - return x + B, H, W, C = x.shape + shortcut = x + x = self.norm1(x) + x_windows = window_partition(x, self.window_size) + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + attn_windows = self.attn(x_windows, q_global) + x = window_reverse(attn_windows, self.window_size, H, W) + x = shortcut + self.drop_path(self.gamma1 * x) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + return x class GlobalQueryGen(nn.Module): @@ -474,6 +474,9 @@ def __init__(self, input_resolution: input image resolution. window_size: window size. num_heads: number of heads. + + For instance, repeating log2(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at + down-sampling ratio 2. Please check Fig. 5 of the GC ViT paper for details. """ super().__init__()