-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathGlobalAttention.py
executable file
·208 lines (171 loc) · 7.01 KB
/
GlobalAttention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import torch
import torch.nn as nn
from UtilClass import BottleLinear
def aeq(*args):
    """
    Assert all arguments have the same value.

    Every argument is compared against the first one with ``==``.
    Calling with zero or one argument is trivially satisfied and
    returns without error.

    Raises:
        AssertionError: if any argument differs from the first one.
    """
    # Nothing to compare: previously this leaked a bare StopIteration
    # out of next() on the exhausted generator.
    if not args:
        return
    first = args[0]
    assert all(arg == first for arg in args[1:]), \
        "Not all arguments have the same value: " + str(args)
def sequence_mask(lengths, max_len=None):
    """
    Creates a boolean mask from sequence lengths.

    Args:
        lengths: 1-D tensor of sequence lengths, one entry per batch item.
        max_len: number of mask columns. When ``None`` the largest value
            in ``lengths`` is used (0 for an empty ``lengths``).

    Returns:
        Tensor of shape ``len(lengths) x max_len`` whose entry ``[b, i]``
        is true iff ``i < lengths[b]``.
    """
    # Test against None explicitly: `max_len or ...` silently discarded an
    # explicit max_len of 0.
    if max_len is None:
        # .max() raises on an empty tensor, so guard that case.
        max_len = lengths.max() if lengths.numel() > 0 else 0
    max_len = int(max_len)
    positions = torch.arange(0, max_len).type_as(lengths)
    # (1, max_len) < (batch, 1) broadcasts to (batch, max_len) without
    # materialising a repeated copy of the position row.
    return positions.unsqueeze(0).lt(lengths.unsqueeze(1))
class GlobalAttention(nn.Module):
    r"""
    Luong Attention.

    Global attention takes a matrix and a query vector. It
    then computes a parameterized convex combination of the matrix
    based on the input query.

        H_1 H_2 H_3 ... H_n
         q   q   q       q
         |   |   |       |
          \  |   |      /
              .....
            \   |   /
                a

    Constructs a unit mapping.
    $$(H_1 + H_n, q) => (a)$$
    Where H is of `batch x n x dim` and q is of `batch x dim`.

    Luong Attention (dot, general):
    The full function is
    $$\tanh(W_2 [(softmax((W_1 q + b_1) H) H), q] + b_2)$$.

    * dot: $$score(h_t,{\overline{h}}_s) = h_t^T{\overline{h}}_s$$
    * general: $$score(h_t,{\overline{h}}_s) = h_t^T W_a {\overline{h}}_s$$

    Bahdanau Attention (mlp):
    $$c = \sum_{j=1}^{SeqLength} a_j h_j$$.
    The Alignment-function $$a$$ computes an alignment as:
    $$a_j = softmax(v_a^T \tanh(W_a q + U_a h_j) )$$.

    Args:
        dim: size of the query and context feature dimension.
        attn_type: one of "dot", "general" or "mlp".
        include_rnn: when true, the context vector is concatenated with
            the query and projected back to ``dim``; otherwise the
            context vector is returned directly.
        dropout: dropout probability applied to the attentional output.
        bias: whether the "general" input projection and the output
            projection use a bias ("mlp" always uses an output bias).
    """

    def __init__(self, dim, attn_type="dot", include_rnn=True, dropout=0.0, bias=False):
        super(GlobalAttention, self).__init__()

        self.dim = dim
        self.attn_type = attn_type
        self.include_rnn = include_rnn
        self.drop = nn.Dropout(dropout)
        assert (self.attn_type in ["dot", "general", "mlp"]), (
            "Please select a valid attention type.")

        if self.attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=bias)
        elif self.attn_type == "mlp":
            self.linear_context = BottleLinear(dim, dim, bias=False)
            self.linear_query = nn.Linear(dim, dim, bias=True)
            self.v = BottleLinear(dim, 1, bias=False)
        # mlp wants it with bias
        out_bias = self.attn_type == "mlp" or bias
        self.linear_out = nn.Linear(dim*2, dim, bias=out_bias)

        self.sm = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.mask = None

    def applyMask(self, mask):
        """
        Store `mask` on the module.

        NOTE(review): `forward` only validates the shape of this mask
        (beam x batch x src_len); it is never applied to the scores
        there. Confirm whether that is intended.
        """
        self.mask = mask

    def score(self, h_t, h_s):
        """
        h_t (FloatTensor): batch x tgt_len x dim
        h_s (FloatTensor): batch x src_len x dim
        returns scores (FloatTensor): batch x tgt_len x src_len:
            raw attention scores for each src index
        """
        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.contiguous().view(tgt_batch*tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)

        dim = self.dim
        # .contiguous() mirrors the "general" branch: .view() fails on a
        # non-contiguous query (e.g. a transposed decoder output).
        wq = self.linear_query(h_t.contiguous().view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = self.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)

    def forward(self, input, context, context_lengths_or_mask):
        """
        input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output.
        context (FloatTensor): batch x src_len x dim: src hidden states
        context_lengths_or_mask: either a 1-D tensor of source lengths,
            a 2-D batch x src_len mask whose nonzero/True entries mark
            the source positions to attend to, or None for no masking.

        returns (attn_h, align_vectors):
            attn_h (FloatTensor): batch x tgt_len x dim
            align_vectors (FloatTensor): batch x tgt_len x src_len

        Raises ValueError if `context_lengths_or_mask` is neither 1-D
        nor 2-D.
        """
        # one step input
        one_step = input.dim() == 2
        if one_step:
            input = input.unsqueeze(1)

        batch, sourceL, dim = context.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if self.mask is not None:
            beam_, batch_, sourceL_ = self.mask.size()
            aeq(batch, batch_*beam_)
            aeq(sourceL, sourceL_)

        # compute attention scores, as in Luong et al.
        align = self.score(input, context)

        if context_lengths_or_mask is not None:
            n_dims = context_lengths_or_mask.dim()
            if n_dims == 1:  # lengths, not a mask
                # Size the mask to the context, not to max(lengths);
                # otherwise the fill fails (or silently broadcasts)
                # whenever the longest sequence is shorter than src_len.
                mask = sequence_mask(context_lengths_or_mask.data,
                                     max_len=sourceL)
            elif n_dims == 2:  # already a mask
                mask = context_lengths_or_mask.data
            else:
                raise ValueError(
                    "context_lengths_or_mask must be 1-D (lengths) or "
                    "2-D (mask), got %d dimensions" % n_dims)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            # `mask == 0` is a logical negation for bool and integer
            # masks alike (`~` is bitwise on integer tensors); the
            # out-of-place fill keeps the operation visible to autograd.
            align = align.masked_fill(mask == 0, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch*targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, context)

        # concatenate
        if self.include_rnn:
            concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
            attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        else:
            attn_h = c.view(batch, targetL, dim)

        attn_h = self.drop(attn_h)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        # TODO: One step This is broken
        if one_step:
            assert False, "one-step attention is currently broken (see TODO)"
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            # Check output sizes
            batch_, targetL_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)

            batch_, targetL_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors