Rebasing tpu branch on a more recent fairseq upstream commit #19
fairseq/data/data_utils.py
@@ -26,9 +26,25 @@ def infer_language_pair(path):
     return src, dst


-def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
+def get_pad_size(values, input_shapes):
+    if input_shapes is None:
+        return max(v.size(0) for v in values)
+    for batch_size, padlen in input_shapes:
+        if len(values) == batch_size:
+            return padlen
+    else:
+        raise IndexError(
+            'Encountered values with invalid length {}, input shapes were {}'
+            .format(len(values), input_shapes)
+        )
+
+
+def collate_tokens(
+    values, pad_idx, eos_idx=None, left_pad=False,
+    move_eos_to_beginning=False, input_shapes=None,
+):
     """Convert a list of 1d tensors into a padded 2d tensor."""
-    size = max(v.size(0) for v in values)
+    size = get_pad_size(values, input_shapes)
     res = values[0].new(len(values), size).fill_(pad_idx)

     def copy_tensor(src, dst):
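To make the new `input_shapes` behaviour concrete, here is a small standalone sketch (not code from this PR; the example tensors are invented). When `input_shapes` is given, the pad length is looked up from the bucket whose batch size matches, rather than being derived from the longest sample in the batch:

```python
import torch

# Standalone illustration of get_pad_size; mirrors the diff above but is
# slightly simplified (plain raise instead of the for/else) and uses invented data.
def get_pad_size(values, input_shapes):
    if input_shapes is None:
        # Dynamic padding: pad to the longest sample in this batch.
        return max(v.size(0) for v in values)
    for batch_size, padlen in input_shapes:
        if len(values) == batch_size:
            # Fixed padding: every batch of this size gets the same length.
            return padlen
    raise IndexError(
        'Encountered values with invalid length {}, input shapes were {}'
        .format(len(values), input_shapes)
    )

values = [torch.tensor([4, 5, 2]), torch.tensor([6, 2])] + [torch.tensor([2])] * 6
print(get_pad_size(values, None))                 # 3  (longest sample)
print(get_pad_size(values, [(8, 64), (16, 32)]))  # 64 (bucket for batches of 8)
```

Fixed shapes matter on TPU because XLA recompiles whenever tensor shapes change; padding every batch of a given size to the same length keeps the compiled graph stable.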
@@ -227,10 +243,25 @@ def batch_by_size(
     if isinstance(indices, types.GeneratorType):
         indices = np.fromiter(indices, dtype=np.int64, count=-1)

     return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)


+def batch_by_size_tpu(
+    indices, num_tokens_fn, input_shapes
+):
+    batches = [[] for _ in input_shapes]
+    for idx in indices:
+        sample_len = num_tokens_fn(idx)
+        for j, (batch_size, padlen) in enumerate(input_shapes):
+            if padlen < sample_len:
+                continue
+            batches[j].append(idx)
+            if len(batches[j]) == batch_size:
+                yield batches[j]
+                batches[j] = []
+            break
+
+
 def process_bpe_symbol(sentence: str, bpe_symbol: str):
     if bpe_symbol == 'sentencepiece':
         sentence = sentence.replace(' ', '').replace('\u2581', ' ').strip()

Reviewer: Is this assuming that the
Reply: yes
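A standalone sketch of how `batch_by_size_tpu` buckets dataset indices into the fixed `(batch_size, padlen)` shapes; the sample lengths and shapes below are invented for illustration, and, as in the generator above, leftover partially filled buckets are simply dropped:

```python
# Same generator as in the diff above, reproduced here so the example runs
# standalone; the sample lengths and shapes are made up.
def batch_by_size_tpu(indices, num_tokens_fn, input_shapes):
    batches = [[] for _ in input_shapes]
    for idx in indices:
        sample_len = num_tokens_fn(idx)
        for j, (batch_size, padlen) in enumerate(input_shapes):
            if padlen < sample_len:
                continue               # too long for this bucket, try the next one
            batches[j].append(idx)
            if len(batches[j]) == batch_size:
                yield batches[j]       # bucket full: emit a fixed-shape batch
                batches[j] = []
            break                      # each sample lands in exactly one bucket

lengths = {0: 10, 1: 60, 2: 12, 3: 70, 4: 8, 5: 9}
shapes = [(2, 16), (2, 64), (2, 128)]  # ordered by increasing padlen
for batch in batch_by_size_tpu(range(6), lengths.get, shapes):
    print(batch)
# [0, 2] -> padded to 16;  [4, 5] -> padded to 16
# indices 1 and 3 stay in half-filled buckets and are never yielded
```

Because each sample goes into the first bucket whose `padlen` fits it, the buckets only pick the tightest padding when `input_shapes` is ordered by increasing `padlen`.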
fairseq/data/language_pair_dataset.py
@@ -11,15 +11,15 @@
 def collate(
     samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
-    input_feeding=True,
+    input_feeding=True, input_shapes=None,
 ):
     if len(samples) == 0:
         return {}

     def merge(key, left_pad, move_eos_to_beginning=False):
         return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
+            [s[key] for s in samples], pad_idx,
+            eos_idx, left_pad, move_eos_to_beginning, input_shapes,
         )

     def check_alignment(alignment, src_len, tgt_len):
@@ -154,7 +154,8 @@ def __init__(
         shuffle=True, input_feeding=True,
         remove_eos_from_source=False, append_eos_to_target=False,
         align_dataset=None,
-        append_bos=False
+        append_bos=False,
+        input_shapes=None,
     ):
         if tgt_dict is not None:
             assert src_dict.pad() == tgt_dict.pad()
@@ -178,6 +179,7 @@ def __init__(
         if self.align_dataset is not None:
             assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
         self.append_bos = append_bos
+        self.input_shapes = input_shapes

     def __getitem__(self, index):
         tgt_item = self.tgt[index] if self.tgt is not None else None

Reviewer: Optional: maybe we could add a docstring for this guy and clarify how it should be sorted?
Reply: See https://github.com/pytorch-tpu/fairseq/blob/tpu/train.py#L291-L298; we error out while parsing the input args if the shapes passed in don't satisfy the assumption, and describe the requirements there.
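The reply above points at a check in `train.py` that rejects badly ordered shapes; that code is not part of this diff, so the sketch below is only an assumed illustration of such a validation: buckets sorted by increasing `padlen` (the ordering `batch_by_size_tpu` implicitly relies on), with batch sizes shrinking as the padded length grows.

```python
from typing import List, Tuple

def validate_input_shapes(shapes: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Hypothetical validation; the real rule is enforced in train.py and
    may differ in its details."""
    # Sort by padded length so the greedy search in batch_by_size_tpu always
    # tries the tightest bucket first.
    shapes = sorted(shapes, key=lambda s: s[1])
    for (bsz_short, _), (bsz_long, _) in zip(shapes, shapes[1:]):
        if bsz_long >= bsz_short:
            raise ValueError(
                'expected batch size to shrink as padlen grows, got {}'.format(shapes)
            )
    return shapes

print(validate_input_shapes([(64, 256), (256, 64), (128, 128)]))
# [(256, 64), (128, 128), (64, 256)]
```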
@@ -249,7 +251,7 @@ def collater(self, samples):
         return collate(
             samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
             left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
-            input_feeding=self.input_feeding,
+            input_feeding=self.input_feeding, input_shapes=self.input_shapes,
         )

     def num_tokens(self, index):
fairseq/modules/multihead_attention.py
@@ -39,9 +39,17 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
             'value to be of the same size'

-        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
-        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        if self.qkv_same_dim:
+            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+        else:
+            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+        if bias:
+            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+        else:
+            self.register_parameter('in_proj_bias', None)

         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -57,11 +65,12 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=

         self.onnx_trace = False

+        # XXX: (taylanbil) try F.multi...
         self.enable_torch_version = False
-        if hasattr(F, "multi_head_attention_forward"):
-            self.enable_torch_version = True
-        else:
-            self.enable_torch_version = False
+        # if hasattr(F, "multi_head_attention_forward"):
+        #     self.enable_torch_version = True
+        # else:
+        #     self.enable_torch_version = False

     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
@@ -70,15 +79,15 @@ def reset_parameters(self):
         if self.qkv_same_dim:
             # Empirically observed the convergence to be much better with
             # the scaled initialization
-            nn.init.xavier_uniform_(self.k_proj.weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.v_proj.weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.in_proj_weight, gain=1/math.sqrt(2))
         else:
-            nn.init.xavier_uniform_(self.k_proj.weight)
-            nn.init.xavier_uniform_(self.v_proj.weight)
-            nn.init.xavier_uniform_(self.q_proj.weight)
+            nn.init.xavier_uniform_(self.k_proj_weight)
+            nn.init.xavier_uniform_(self.v_proj_weight)
+            nn.init.xavier_uniform_(self.q_proj_weight)

         nn.init.xavier_uniform_(self.out_proj.weight)
         if self.in_proj_bias is not None:
             nn.init.constant_(self.in_proj_bias, 0.)
             nn.init.constant_(self.out_proj.bias, 0.)
         if self.bias_k is not None:
             nn.init.xavier_normal_(self.bias_k)
@@ -146,23 +155,19 @@ def forward(
         saved_state = None

         if self.self_attention:
-            q = self.q_proj(query)
-            k = self.k_proj(query)
-            v = self.v_proj(query)
+            q, k, v = self.in_proj_qkv(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
-            q = self.q_proj(query)
+            q = self.in_proj_q(query)
             if key is None:
                 assert value is None
                 k = v = None
             else:
-                k = self.k_proj(key)
-                v = self.v_proj(key)
+                k = self.in_proj_k(key)
+                v = self.in_proj_v(key)

         else:
-            q = self.q_proj(query)
-            k = self.k_proj(key)
-            v = self.v_proj(value)
+            raise
         q *= self.scaling

         if self.bias_k is not None:

Reviewer: For my understanding, was this for performance improvements? If so, did it help?
Reply: Yes, this was causing a 10% regression.
@@ -242,10 +247,9 @@ def forward(
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2),
-                float('-inf'),
-            )
+            attn_weights = attn_weights.transpose(0, 2)
+            attn_weights.masked_fill_(key_padding_mask, float('-inf'))
+            attn_weights = attn_weights.transpose(0, 2)
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         if before_softmax:
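The rewritten masking avoids the `unsqueeze` calls by transposing so that `key_padding_mask`'s `(bsz, src_len)` shape lines up with the trailing dimensions; the result is the same. A quick standalone check, with shapes invented for the example:

```python
import torch

# Hypothetical shapes; in the real module these tensors come from the
# attention computation above.
bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7
attn_weights = torch.randn(bsz, num_heads, tgt_len, src_len)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # pretend the last two source positions are padding

# Upstream formulation: broadcast the mask over heads and target positions.
a = attn_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
)

# TPU-branch formulation: transpose so (bsz, src_len) lines up with the
# trailing dims, mask in place, transpose back.
b = attn_weights.clone().transpose(0, 2)
b.masked_fill_(key_padding_mask, float('-inf'))
b = b.transpose(0, 2)

print(torch.equal(a, b))  # True
```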
@@ -330,3 +334,43 @@ def upgrade_state_dict_named(self, state_dict, name):

         for key, value in items_to_add.items():
             state_dict[key] = value
+
+    def in_proj_qkv(self, query):
+        return self._in_proj(query).chunk(3, dim=-1)
+
+    def in_proj_q(self, query):
+        if self.qkv_same_dim:
+            return self._in_proj(query, end=self.embed_dim)
+        else:
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[:self.embed_dim]
+            return F.linear(query, self.q_proj_weight, bias)
+
+    def in_proj_k(self, key):
+        if self.qkv_same_dim:
+            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        else:
+            weight = self.k_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[self.embed_dim:2 * self.embed_dim]
+            return F.linear(key, weight, bias)
+
+    def in_proj_v(self, value):
+        if self.qkv_same_dim:
+            return self._in_proj(value, start=2 * self.embed_dim)
+        else:
+            weight = self.v_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[2 * self.embed_dim:]
+            return F.linear(value, weight, bias)
+
+    def _in_proj(self, input, start=0, end=None):
+        weight = self.in_proj_weight
+        bias = self.in_proj_bias
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
+        return F.linear(input, weight, bias)
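Per the review exchange above, splitting the input projection into three `nn.Linear` calls reportedly cost about 10% on TPU, so the branch keeps the single stacked `in_proj_weight`. A standalone sketch (dimensions invented) of the equivalence the helpers above implement: one `(3*embed_dim, embed_dim)` matmul chunked into q/k/v gives the same result as three row-slice projections.

```python
import torch
import torch.nn.functional as F

# Standalone sketch with invented dimensions; rows 0:E of the stacked weight
# feed q, E:2E feed k, 2E:3E feed v, exactly as _in_proj slices them.
E = 8                                   # embed_dim
x = torch.randn(5, 2, E)                # (tgt_len, bsz, embed_dim)
in_proj_weight = torch.randn(3 * E, E)
in_proj_bias = torch.randn(3 * E)

# Fused path, as in in_proj_qkv/_in_proj above: one linear, then chunk.
q, k, v = F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

# Equivalent "three separate projections" path (what separate q/k/v Linear
# modules would compute if their weights were these row slices).
q2 = F.linear(x, in_proj_weight[:E], in_proj_bias[:E])
k2 = F.linear(x, in_proj_weight[E:2 * E], in_proj_bias[E:2 * E])
v2 = F.linear(x, in_proj_weight[2 * E:], in_proj_bias[2 * E:])

print(torch.allclose(q, q2), torch.allclose(k, k2), torch.allclose(v, v2))  # True True True
```

One large matmul also tends to keep the TPU's matrix units better utilized than three smaller ones, which is consistent with the regression noted in the review above.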
Reviewer: Maybe move this out of the `try` into a local and reuse.
Reply: Will do.