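"""Non-local attention modules.

NonLocalSparseAttention approximates non-local attention with multi-round
locality-sensitive hashing (LSH): positions are bucketed by hash code and
attention is computed only within each bucket (plus its adjacent buckets).
NonLocalAttention is the dense counterpart that attends over all positions.
"""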
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import common


class NonLocalSparseAttention(nn.Module):
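    """Sparse non-local attention via multi-round locality-sensitive hashing.

    Positions are hashed into buckets, sorted so that positions with the same
    bucket become adjacent, and attention is computed only inside each chunk
    of sorted positions (plus its two neighbouring chunks). The results of the
    ``n_hashes`` rounds are combined by a softmax-weighted sum.
    """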
    def __init__(self, n_hashes=4, channels=64, k_size=3, reduction=4,
                 chunk_size=144, conv=common.default_conv, res_scale=1):
        super(NonLocalSparseAttention, self).__init__()
        self.chunk_size = chunk_size
        self.n_hashes = n_hashes
        self.reduction = reduction
        self.res_scale = res_scale
        self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None)
        self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None)

    def LSH(self, hash_buckets, x):
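        """Assign each position to one bucket per hash round.

        x: [N, H*W, C]. Projects the embedding onto random rotations, takes the
        argmax over the concatenated [r, -r] projections as the bucket index,
        then offsets the codes so buckets from different rounds never collide.
        Returns hash codes of shape [N, n_hashes*H*W].
        """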
        # x: [N, H*W, C]
        N = x.shape[0]
        device = x.device

        # generate random rotation matrix
        rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2)  # [1, C, n_hashes, hash_buckets//2]
        random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1)  # [N, C, n_hashes, hash_buckets//2]

        # locality-sensitive hashing
        rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations)  # [N, n_hashes, H*W, hash_buckets//2]
        rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)  # [N, n_hashes, H*W, hash_buckets]

        # get hash codes
        hash_codes = torch.argmax(rotated_vecs, dim=-1)  # [N, n_hashes, H*W]

        # add offsets to avoid hash codes overlapping between hash rounds
        offsets = torch.arange(self.n_hashes, device=device)
        offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1))
        hash_codes = torch.reshape(hash_codes + offsets, (N, -1,))  # [N, n_hashes*H*W]

        return hash_codes

    def add_adjacent_buckets(self, x):
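        """Append the previous and next chunk to every chunk along the
        within-chunk axis (dim 3), so each query can also attend to keys
        in adjacent buckets."""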
        x_extra_back = torch.cat([x[:, :, -1:, ...], x[:, :, :-1, ...]], dim=2)
        x_extra_forward = torch.cat([x[:, :, 1:, ...], x[:, :, :1, ...]], dim=2)
        return torch.cat([x, x_extra_back, x_extra_forward], dim=3)

    def forward(self, input):
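        """input: [N, C, H, W]. Returns a tensor of the same shape.

        Pipeline: embed -> LSH bucket assignment -> sort positions by bucket ->
        chunk -> attention within each chunk (and its adjacent chunks) ->
        undo the sort -> softmax-weighted sum over hash rounds -> residual.
        """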
        N, _, H, W = input.shape
        x_embed = self.conv_match(input).view(N, -1, H*W).contiguous().permute(0, 2, 1)
        y_embed = self.conv_assembly(input).view(N, -1, H*W).contiguous().permute(0, 2, 1)
        L, C = x_embed.shape[-2:]

        # number of hash buckets/hash bits
        hash_buckets = min(L//self.chunk_size + (L//self.chunk_size) % 2, 128)

        # get assigned hash codes/bucket number
        hash_codes = self.LSH(hash_buckets, x_embed)  # [N, n_hashes*H*W]
        hash_codes = hash_codes.detach()

        # group elements with the same hash code by sorting
        _, indices = hash_codes.sort(dim=-1)  # [N, n_hashes*H*W]
        _, undo_sort = indices.sort(dim=-1)  # undo_sort to recover the original order
        mod_indices = (indices % L)  # now in the range 0 -> H*W
        x_embed_sorted = common.batched_index_select(x_embed, mod_indices)  # [N, n_hashes*H*W, C]
        y_embed_sorted = common.batched_index_select(y_embed, mod_indices)  # [N, n_hashes*H*W, C*reduction]

        # pad the embedding if it cannot be divided by chunk_size
        padding = self.chunk_size - L % self.chunk_size if L % self.chunk_size != 0 else 0
        x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes, -1, C))  # [N, n_hashes, H*W, C]
        y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes, -1, C*self.reduction))
        if padding:
            pad_x = x_att_buckets[:, :, -padding:, :].clone()
            pad_y = y_att_buckets[:, :, -padding:, :].clone()
            x_att_buckets = torch.cat([x_att_buckets, pad_x], dim=2)
            y_att_buckets = torch.cat([y_att_buckets, pad_y], dim=2)
        x_att_buckets = torch.reshape(x_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C))  # [N, n_hashes, num_chunks, chunk_size, C]
        y_att_buckets = torch.reshape(y_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C*self.reduction))

        x_match = F.normalize(x_att_buckets, p=2, dim=-1, eps=5e-5)

        # allow attending to adjacent buckets
        x_match = self.add_adjacent_buckets(x_match)
        y_att_buckets = self.add_adjacent_buckets(y_att_buckets)

        # unnormalized attention score
        raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match)  # [N, n_hashes, num_chunks, chunk_size, chunk_size*3]

        # softmax
        bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
        score = torch.exp(raw_score - bucket_score)  # (after softmax)
        bucket_score = torch.reshape(bucket_score, [N, self.n_hashes, -1])

        # attention
        ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets)  # [N, n_hashes, num_chunks, chunk_size, C*reduction]
        ret = torch.reshape(ret, (N, self.n_hashes, -1, C*self.reduction))

        # if padded, then remove extra elements
        if padding:
            ret = ret[:, :, :-padding, :].clone()
            bucket_score = bucket_score[:, :, :-padding].clone()

        # recover the original order
        ret = torch.reshape(ret, (N, -1, C*self.reduction))  # [N, n_hashes*H*W, C*reduction]
        bucket_score = torch.reshape(bucket_score, (N, -1,))  # [N, n_hashes*H*W]
        ret = common.batched_index_select(ret, undo_sort)  # [N, n_hashes*H*W, C*reduction]
        bucket_score = bucket_score.gather(1, undo_sort)  # [N, n_hashes*H*W]

        # weighted sum over multi-round attention
        ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction))  # [N, n_hashes, H*W, C*reduction]
        bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
        probs = nn.functional.softmax(bucket_score, dim=1)
        ret = torch.sum(ret * probs, dim=1)

        ret = ret.permute(0, 2, 1).view(N, -1, H, W).contiguous()*self.res_scale + input
        return ret


class NonLocalAttention(nn.Module):
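    """Standard (dense) non-local attention: every position attends to all
    H*W positions through a softmax over the full similarity matrix."""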
    def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=1,
                 softmax_scale=10, average=True, res_scale=1, conv=common.default_conv):
        super(NonLocalAttention, self).__init__()
        self.res_scale = res_scale
        self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

    def forward(self, input):
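        """input: [N, C, H, W]. Returns the attention output plus res_scale * input."""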
        x_embed_1 = self.conv_match1(input)
        x_embed_2 = self.conv_match2(input)
        x_assembly = self.conv_assembly(input)

        N, C, H, W = x_embed_1.shape
        x_embed_1 = x_embed_1.permute(0, 2, 3, 1).view((N, H*W, C))
        x_embed_2 = x_embed_2.view(N, C, H*W)

        # pairwise similarity between all positions, softmax over the keys
        score = torch.matmul(x_embed_1, x_embed_2)  # [N, H*W, H*W]
        score = F.softmax(score, dim=2)

        # aggregate the assembly features and add the scaled residual
        x_assembly = x_assembly.view(N, -1, H*W).permute(0, 2, 1)  # [N, H*W, channel]
        x_final = torch.matmul(score, x_assembly)  # [N, H*W, channel]
        return x_final.permute(0, 2, 1).view(N, -1, H, W) + self.res_scale*input
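

if __name__ == "__main__":
    # Minimal usage sketch: it assumes this file sits in the original repo
    # layout, with model/common.py providing default_conv, BasicBlock and
    # batched_index_select (as the imports above require). It only checks
    # that both attention modules preserve the input shape.
    nlsa = NonLocalSparseAttention(channels=64, n_hashes=4, chunk_size=144)
    nla = NonLocalAttention(channel=64, reduction=2)
    x = torch.randn(2, 64, 48, 48)
    print(nlsa(x).shape)  # expected: torch.Size([2, 64, 48, 48])
    print(nla(x).shape)   # expected: torch.Size([2, 64, 48, 48])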