Commit a03fe6f

Sara Hanson authored and facebook-github-bot committed
Implement sparse transformer fixed attention pattern (#804)
Summary:
Pull Request resolved: facebookresearch/pytext#804
Pull Request resolved: fairinternal/fairseq-py#746
Pull Request resolved: #894

Adds an implementation of the sparse transformer to multi-head attention, using the fixed attention pattern specified in https://arxiv.org/pdf/1904.10509.pdf. The sparse_mask masks out words using -inf; after softmax, -inf becomes 0, so the mask does not need to be re-calculated and re-applied when multiplying attn_weights and values.

Four inputs are added to the config: sparse, is_bidirectional, stride, and expressivity. If the sparse transformer is used, is_bidirectional, stride, and expressivity must be specified (defaults are provided). If is_bidirectional is False, values are masked using the fixed attention pattern described in the paper. If is_bidirectional is True, subset one includes all values in the current stride window and a summary from every stride window; all other values are masked. Stride (L in the paper) controls the window size and expressivity (c in the paper) controls the size of the summary.

Reviewed By: borguz

Differential Revision: D16042988

fbshipit-source-id: c59166dc7cfe89187a256e4076000c2458842fd5
1 parent e8d609a commit a03fe6f

5 files changed: +296 −2 lines changed
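For illustration only (a hedged sketch, not part of the diff): with stride=4 and expressivity=1 in the unidirectional setting, position 6 of an 8-token sequence attends to its own stride window plus the summary of the previous window, matching row index 6 of the unidirectional mask in the new unit test.

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention

# Illustrative hyper-parameters: embed_dim=16, one head, l (stride) = 4, c (expressivity) = 1.
attn = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False)

# Subset one (current stride window) union subset two (window summaries).
print(sorted(attn.compute_fixed_attention_subset(word_index=6, tgt_len=8)))
# Expected under the fixed pattern: [3, 4, 5, 6]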

fairseq/modules/multihead_attention.py

+5-2
@@ -40,7 +40,6 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
             'value to be of the same size'
 
-
         if self.qkv_same_dim:
             self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
         else:
@@ -102,7 +101,6 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No
         the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
         batch x src_len, where padding elements are indicated by 1s.
         """
-
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
@@ -217,6 +215,8 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No
                 [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
 
         if attn_mask is not None:
@@ -327,3 +327,6 @@ def _set_input_buffer(self, incremental_state, buffer):
             'attn_state',
             buffer,
         )
+
+    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+        return attn_weights
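The base-class hook is intentionally a no-op, so plain MultiheadAttention behaviour is unchanged; a minimal sanity sketch (illustrative, not part of the commit):

import torch
from fairseq.modules.multihead_attention import MultiheadAttention

mha = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
weights = torch.randn(4, 8, 8)  # (bsz * num_heads, tgt_len, src_len) with bsz=1

# The default hook returns its input untouched; SparseMultiheadAttention overrides it below.
assert torch.equal(mha.apply_sparse_mask(weights, tgt_len=8, src_len=8, bsz=1), weights)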
fairseq/modules/sparse_multihead_attention.py

+106
@@ -0,0 +1,106 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
import torch
from .multihead_attention import MultiheadAttention


class SparseMultiheadAttention(MultiheadAttention):
    """ Sparse Multi-Headed Attention.

    "Generating Long Sequences with Sparse Transformers". Implements
    fixed factorized self attention, where l=stride and c=expressivity.
    A(1) includes all words in the stride window and A(2) takes a summary of c
    words from the end of each stride window.
    If is_bidirectional=False, we do not include any words past the current word,
    as in the paper.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False, stride=32, expressivity=8, is_bidirectional=True):

        super().__init__(
            embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv,
            add_zero_attn, self_attention, encoder_decoder_attention
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride
        self.expressivity = expressivity
        assert(self.stride > 0 and self.stride >= self.expressivity)

    # Used for Ai(2) calculations - beginning of [l-c, l] range
    def compute_checkpoint(self, word_index):
        if word_index % self.stride == 0 and word_index != 0:
            checkpoint_index = word_index - self.expressivity
        else:
            checkpoint_index = (
                math.floor(word_index / self.stride) * self.stride
                + self.stride - self.expressivity
            )
        return checkpoint_index

    # Computes Ai(2)
    def compute_subset_summaries(self, absolute_max):
        checkpoint_index = self.compute_checkpoint(0)
        subset_two = set()
        while checkpoint_index <= absolute_max - 1:
            summary = set(range(checkpoint_index, min(
                checkpoint_index + self.expressivity + 1, absolute_max)
            ))
            subset_two = subset_two.union(summary)
            checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride)
        return subset_two

    # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf
    def compute_fixed_attention_subset(self, word_index, tgt_len):
        # +1s account for range function; [min, max) -> [min, max]
        if not self.is_bidirectional:
            absolute_max = word_index + 1
        else:
            absolute_max = tgt_len

        # Subset 1 - whole window
        rounded_index = math.floor((word_index + self.stride) / self.stride) * self.stride
        if word_index % self.stride == 0 and word_index != 0:
            subset_one = set(range(word_index - self.stride, min(absolute_max, word_index + 1)))
        else:
            subset_one = set(range(max(0, rounded_index - self.stride), min(
                absolute_max, rounded_index + 1))
            )

        # Subset 2 - summary per window
        # If bidirectional, subset 2 is the same for every index
        subset_two = set()
        if not self.is_bidirectional:
            subset_two = self.compute_subset_summaries(absolute_max)

        return subset_one.union(subset_two)

    # Compute sparse mask - if bidirectional, can pre-compute and store
    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        assert(tgt_len > self.stride)
        sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float('-inf'))

        # If bidirectional, subset 2 is the same for every index
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
            fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
            included_word_indices = torch.LongTensor(list(fixed_attention_subset))
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(bsz * self.num_heads, tgt_len, src_len)
        attn_weights += sparse_mask
        return attn_weights
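As a usage sketch (illustrative sizes, not part of the commit), the buffered mask is a (tgt_len, src_len) grid of 0 (kept) and -inf (masked), and apply_sparse_mask adds it to every head's attention weights before softmax:

import torch
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention

attn = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True)

tgt_len = src_len = 8  # buffered_sparse_mask asserts tgt_len > stride
weights = torch.randn(1 * attn.num_heads, tgt_len, src_len)  # (bsz * num_heads, tgt_len, src_len)

mask = attn.buffered_sparse_mask(weights, tgt_len, src_len)
print(mask.shape)  # torch.Size([8, 8])

weights = attn.apply_sparse_mask(weights, tgt_len, src_len, bsz=1)
print(torch.isinf(weights).any())  # tensor(True): masked positions are -inf and become 0 after softmax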
fairseq/modules/sparse_transformer_sentence_encoder.py

+85
@@ -0,0 +1,85 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.nn as nn
from fairseq.modules import TransformerSentenceEncoder
from fairseq.modules.sparse_transformer_sentence_encoder_layer import SparseTransformerSentenceEncoderLayer


class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Sparse implementation of the TransformerSentenceEncoder
    - see SparseMultiheadAttention
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        super().__init__(
            padding_idx, vocab_size, num_encoder_layers, embedding_dim,
            ffn_embedding_dim, num_attention_heads, dropout, attention_dropout,
            activation_dropout, max_seq_len, num_segments, use_position_embeddings,
            offset_positions_by_padding, encoder_normalize_before, apply_bert_init,
            activation_fn, learned_pos_embedding, add_bias_kv, add_zero_attn,
            embed_scale, freeze_embeddings, n_trans_layers_to_freeze, export
        )

        self.layers = nn.ModuleList(
            [
                SparseTransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    add_bias_kv=add_bias_kv,
                    add_zero_attn=add_zero_attn,
                    export=export,
                    is_bidirectional=is_bidirectional,
                    stride=stride,
                    expressivity=expressivity,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
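A hedged construction sketch (hyper-parameters are illustrative; the import path assumes the new module lives at fairseq/modules/sparse_transformer_sentence_encoder.py). Note that buffered_sparse_mask asserts tgt_len > stride, so input sequences must be longer than the configured stride:

from fairseq.modules.sparse_transformer_sentence_encoder import SparseTransformerSentenceEncoder

# Small, illustrative configuration; stride and expressivity map to l and c in the paper.
encoder = SparseTransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=1000,
    num_encoder_layers=2,
    embedding_dim=64,
    ffn_embedding_dim=128,
    num_attention_heads=4,
    max_seq_len=64,
    is_bidirectional=True,
    stride=8,
    expressivity=2,
)
print(len(encoder.layers))  # 2 SparseTransformerSentenceEncoderLayer modules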
fairseq/modules/sparse_transformer_sentence_encoder_layer.py

+50
@@ -0,0 +1,50 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from fairseq.modules import TransformerSentenceEncoderLayer
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Sparse Transformer Encoder Layer (see SparseMultiheadAttention)
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = 'relu',
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:

        super().__init__(
            embedding_dim, ffn_embedding_dim, num_attention_heads, dropout,
            attention_dropout, activation_dropout, activation_fn, add_bias_kv,
            add_zero_attn, export
        )

        self.self_attn = SparseMultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
            is_bidirectional=is_bidirectional,
            stride=stride,
            expressivity=expressivity,
        )
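Likewise for a single layer (illustrative sizes, not part of the commit), the override swaps SparseMultiheadAttention in as the self-attention module:

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
from fairseq.modules.sparse_transformer_sentence_encoder_layer import SparseTransformerSentenceEncoderLayer

layer = SparseTransformerSentenceEncoderLayer(
    embedding_dim=64,
    ffn_embedding_dim=128,
    num_attention_heads=4,
    is_bidirectional=False,
    stride=8,
    expressivity=2,
)
assert isinstance(layer.self_attn, SparseMultiheadAttention)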
tests/test_sparse_multihead_attention.py

+50
@@ -0,0 +1,50 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import unittest
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class TestSparseMultiheadAttention(unittest.TestCase):
    def test_sparse_multihead_attention(self):
        attn_weights = torch.randn(1, 8, 8)
        bidirectional_sparse_mask = torch.tensor([
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0]
        ])

        bidirectional_attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=True)
        bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
        self.assertTrue(torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)))

        sparse_mask = torch.tensor([
            [0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'),
             float('-inf'), float('-inf')],
            [0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
            [0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
            [0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
            [0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')],
            [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
        ])

        attention = SparseMultiheadAttention(16, 1, stride=4, expressivity=1, is_bidirectional=False)
        attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)

        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))


if __name__ == '__main__':
    unittest.main()
