
Commit 7a23b93

Making TPU training work
* optimizer fix
* progress bar comment out temporarily
* some changes to train_tpu
* int mask instead of float
* pfpfpfpf
* fix printing device index per loop
* bkpt to investigate resize_ call
* attempting to init buffer size to 2*dim
* bkpt
* better print
* do not drop records when computing loss
* Changes that reduce graph compiles:
  * Loss function replaced with an equivalent logic that doesn't resize tensors.
  * cli args changed to guarantee consistency
  * collate_tokens function in fairseq/data/data_utils.py overwritten to guarantee consistency
* undoing some changes made while debugging
* progress_bar implements len
* some irrelevant changes to train_tpu.py
* new xla changes
* bug fix in enable_torch_version
* removing the last batch that is of different size from the iterator
* delete optimizer step in fairseq's trainer
* Added `self.xla` flag that controls whether Trainer includes the optimizer step, and tried to include more explanation of why the optimizer step is skipped this time
* deleted obsolete file
* add norm clipping count back in (#4)
* remove grad norm clip count (#5)
* Change masked_fill_ input in loss in order to accommodate necessary pytorch changes (#6)
* Adding tpu capabilities to train.py (#8)
  * flush when printing for better user experience
  * separated cli_main into parse_args, maingpu and maintpu; deleted unused line in datautils.py
* Enumerate the loader in training and validation (#9)
* Add option to assert on training and/or validation loss (#10)
  * applied suggestion
* None loss should be filled to inf (#11)
* Enabling multiprocessing for fairseq training (#12)
  * initial commit for multiprocess api
  * indentation fixes and import fix
  * no need to softlink, fix save/load
  * Remove the hacks to only save from master ordinal, as xm.save takes care of that
  * fix indentation; 3 -> 4 spaces
  * Moved xu.eprints after spawn and dropping last batches better
* trainers -> trainer (#13)
* fix bug in assert_on_losses
* Replace usage of unsqueeze with transpose + broadcasting (#15)
* remove attn mask + loss rewrite + save per host + format
* suppress loss report
* allow usage of batch_by_size in translation
* attn_weights masked fill in place
* Clean up the log output, suppressing a bit
* Revert multihead attn's in_proj code changes: the non-rebased tpu branch is about 10% faster on TPUs compared to the rebased branch, and the regression is inside multihead attn's in_proj mechanism, so the relevant changes are reverted to preserve performance
* Pass correct args to the new get_valid_stats function
* Send meters to device in order not to fail training when resuming from checkpoint
1 parent a0f7599 commit 7a23b93

11 files changed: +530 −85 lines
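The thread running through these changes is that XLA compiles a separate graph for every distinct combination of tensor shapes and operations it sees, so variable batch shapes, tensor resizing, and data-dependent Python branches each trigger fresh compilations on TPU. As rough orientation, here is a minimal sketch of the kind of torch_xla training step this branch is building toward; model, loader, and loss_fn are placeholders, not objects defined in this commit:

# Minimal sketch of a TPU training step with torch_xla; model, loader, and
# loss_fn are placeholders, not names from this commit.
import torch_xla.core.xla_model as xm

def train_one_epoch(model, loader, loss_fn, optimizer):
    device = xm.xla_device()              # the TPU core assigned to this process
    model = model.to(device)
    for inputs, targets in loader:        # batches should keep fixed shapes to avoid recompiles
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        xm.optimizer_step(optimizer)      # syncs gradients across cores, then steps the optimizer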

fairseq/checkpoint_utils.py

+16-8
@@ -17,6 +17,8 @@
 
 from fairseq.models import FairseqEncoder, FairseqDecoder
 
+import torch_xla.core.xla_model as xm
+
 
 def save_checkpoint(args, trainer, epoch_itr, val_loss):
     from fairseq import distributed_utils, meters
@@ -62,15 +64,17 @@ def is_better(a, b):
         extra_state.update({'best': save_checkpoint.best})
 
     checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
+
     if len(checkpoints) > 0:
         trainer.save_checkpoint(checkpoints[0], extra_state)
         for cp in checkpoints[1:]:
             try:
                 from fairseq.fb_pathmgr import fb_pathmgr
-                fb_pathmgr.copy(checkpoints[0], cp, True)
+                if getattr(args, 'use_gpu', True) or xm.is_master_ordinal():
+                    fb_pathmgr.copy(checkpoints[0], cp, True)
             except (ModuleNotFoundError, ImportError):
-                shutil.copyfile(checkpoints[0], cp)
-
+                if getattr(args, 'use_gpu', True) or xm.is_master_ordinal():
+                    shutil.copyfile(checkpoints[0], cp)
     write_timer.stop()
     print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
         checkpoints[0], epoch, updates, write_timer.sum))
@@ -97,7 +101,7 @@ def is_better(a, b):
 def load_checkpoint(args, trainer, data_selector=None):
     """Load a checkpoint and restore the training iterator."""
     # only one worker should attempt to create the required dir
-    if args.distributed_rank == 0:
+    if args.distributed_rank == 0 or xm.is_master_ordinal():
         os.makedirs(args.save_dir, exist_ok=True)
 
     if args.restore_file == 'checkpoint_last.pt':
@@ -210,7 +214,8 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
 def torch_persistent_save(*args, **kwargs):
     for i in range(3):
         try:
-            return torch.save(*args, **kwargs)
+            save_func = xm.save if kwargs.pop('xla', False) else torch.save
+            return save_func(*args, **kwargs)
         except Exception:
             if i == 2:
                 logging.error(traceback.format_exc())
@@ -256,14 +261,17 @@ def save_state(
    state_dict['criterion'] = criterion.state_dict()
    if not args.no_save_optimizer_state:
        state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
-
    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        with fb_pathmgr.open(filename, "wb") as f:
-            torch_persistent_save(state_dict, f)
+            torch_persistent_save(
+                state_dict, f, xla=not getattr(args, 'use_gpu', True)
+            )
    except (ModuleNotFoundError, ImportError):
        # if path manager not found, continue with local file.
-        torch_persistent_save(state_dict, filename)
+        torch_persistent_save(
+            state_dict, filename, xla=not getattr(args, 'use_gpu', True)
+        )
 
 
 def _upgrade_state_dict(state):
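The use_gpu / xm.is_master_ordinal() gating above keeps the TPU processes from all writing the same checkpoint, and the new xla flag routes saving through xm.save. A hedged sketch of that pattern on its own; state and save_path are placeholders rather than identifiers from checkpoint_utils.py:

# Sketch: write checkpoints once per host under torch_xla; 'state' and
# 'save_path' are placeholders, not identifiers from this file.
import torch
import torch_xla.core.xla_model as xm

def persistent_save(state, save_path, use_gpu=True):
    if use_gpu:
        torch.save(state, save_path)      # ordinary CPU/GPU path
    else:
        # xm.save moves tensors off the XLA device and, by default, writes
        # only from the master ordinal, so the other processes skip the I/O.
        xm.save(state, save_path)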

fairseq/criterions/label_smoothed_cross_entropy.py

+4-4
@@ -17,8 +17,8 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T
     smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
     if ignore_index is not None:
         non_pad_mask = target.ne(ignore_index)
-        nll_loss = nll_loss[non_pad_mask]
-        smooth_loss = smooth_loss[non_pad_mask]
+        nll_loss.masked_fill_(~non_pad_mask, 0.0)
+        smooth_loss.masked_fill_(~non_pad_mask, 0.0)
     else:
         nll_loss = nll_loss.squeeze(-1)
         smooth_loss = smooth_loss.squeeze(-1)
@@ -57,8 +57,8 @@ def forward(self, model, sample, reduce=True):
         loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': utils.item(loss.data) if reduce else loss.data,
-            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
+            'loss': loss.data,
+            'nll_loss': nll_loss.data,
             'ntokens': sample['ntokens'],
             'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
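Indexing with a boolean mask, as in the removed nll_loss[non_pad_mask], yields a tensor whose length depends on how much padding the batch happens to contain, which is exactly the kind of shape change that forces XLA recompiles; masked_fill_ zeroes the padded positions in place and leaves the shape fixed, while the summed loss stays the same. A small plain-PyTorch check of that equivalence, using illustrative values rather than fairseq code:

# Plain-PyTorch check: zeroing padded positions in place gives the same summed
# loss as dropping them, while the tensor shape stays static.
import torch

pad_idx = 1
target = torch.tensor([[5, 7, 1, 1]])                        # last two tokens are padding
nll_loss = torch.rand(1, 4, 1)                               # per-token losses, shape (B, T, 1)
non_pad_mask = target.unsqueeze(-1).ne(pad_idx)

dropped = nll_loss[non_pad_mask]                             # dynamic shape: (num_non_pad,)
filled = nll_loss.clone().masked_fill_(~non_pad_mask, 0.0)   # static shape: (1, 4, 1)

assert torch.isclose(dropped.sum(), filled.sum())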

fairseq/data/data_utils.py

+34-3
@@ -26,9 +26,25 @@ def infer_language_pair(path):
     return src, dst
 
 
-def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
+def get_pad_size(values, input_shapes):
+    if input_shapes is None:
+        return max(v.size(0) for v in values)
+    for batch_size, padlen in input_shapes:
+        if len(values) == batch_size:
+            return padlen
+    else:
+        raise IndexError(
+            'Encountered values with invalid length {}, input shapes were {}'
+            .format(len(values), input_shapes)
+        )
+
+
+def collate_tokens(
+    values, pad_idx, eos_idx=None, left_pad=False,
+    move_eos_to_beginning=False, input_shapes=None,
+):
     """Convert a list of 1d tensors into a padded 2d tensor."""
-    size = max(v.size(0) for v in values)
+    size = get_pad_size(values, input_shapes)
     res = values[0].new(len(values), size).fill_(pad_idx)
 
     def copy_tensor(src, dst):
@@ -227,10 +243,25 @@ def batch_by_size(
 
     if isinstance(indices, types.GeneratorType):
         indices = np.fromiter(indices, dtype=np.int64, count=-1)
-
     return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)
 
 
+def batch_by_size_tpu(
+    indices, num_tokens_fn, input_shapes
+):
+    batches = [[] for _ in input_shapes]
+    for idx in indices:
+        sample_len = num_tokens_fn(idx)
+        for j, (batch_size, padlen) in enumerate(input_shapes):
+            if padlen < sample_len:
+                continue
+            batches[j].append(idx)
+            if len(batches[j]) == batch_size:
+                yield batches[j]
+                batches[j] = []
+            break
+
+
 def process_bpe_symbol(sentence: str, bpe_symbol: str):
     if bpe_symbol == 'sentencepiece':
         sentence = sentence.replace(' ', '').replace('\u2581', ' ').strip()
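batch_by_size_tpu assigns each sample index to the first (batch_size, pad_length) bucket whose pad length fits it and only yields full buckets, so every emitted batch has one of a small, fixed set of shapes; whatever remains in a partially filled bucket is never yielded, matching the commit message's note about dropping last batches. A standalone usage sketch, restating the generator above with invented lengths and input_shapes:

# Usage sketch: the generator above, restated so the demo is self-contained;
# the lengths and input_shapes below are invented for illustration.
def batch_by_size_tpu(indices, num_tokens_fn, input_shapes):
    batches = [[] for _ in input_shapes]
    for idx in indices:
        sample_len = num_tokens_fn(idx)
        for j, (batch_size, padlen) in enumerate(input_shapes):
            if padlen < sample_len:
                continue                  # sample too long for this bucket
            batches[j].append(idx)
            if len(batches[j]) == batch_size:
                yield batches[j]          # a full, fixed-shape batch
                batches[j] = []
            break                         # first bucket that fits wins

lengths = {0: 12, 1: 30, 2: 14, 3: 9, 4: 28, 5: 15}          # index -> token count
input_shapes = [(3, 16), (2, 32)]                            # (batch_size, pad_length)
print(list(batch_by_size_tpu(lengths, lengths.get, input_shapes)))
# -> [[0, 2, 3], [1, 4]]; index 5 sits in a partially filled bucket and is dropped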

fairseq/data/language_pair_dataset.py

+7-5
@@ -11,15 +11,15 @@
 
 def collate(
     samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
-    input_feeding=True,
+    input_feeding=True, input_shapes=None,
 ):
     if len(samples) == 0:
         return {}
 
     def merge(key, left_pad, move_eos_to_beginning=False):
         return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
+            [s[key] for s in samples], pad_idx,
+            eos_idx, left_pad, move_eos_to_beginning, input_shapes,
         )
 
     def check_alignment(alignment, src_len, tgt_len):
@@ -154,7 +154,8 @@ def __init__(
         shuffle=True, input_feeding=True,
         remove_eos_from_source=False, append_eos_to_target=False,
         align_dataset=None,
-        append_bos=False
+        append_bos=False,
+        input_shapes=None,
     ):
         if tgt_dict is not None:
             assert src_dict.pad() == tgt_dict.pad()
@@ -178,6 +179,7 @@ def __init__(
         if self.align_dataset is not None:
             assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
         self.append_bos = append_bos
+        self.input_shapes = input_shapes
 
     def __getitem__(self, index):
         tgt_item = self.tgt[index] if self.tgt is not None else None
@@ -249,7 +251,7 @@ def collater(self, samples):
         return collate(
             samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
             left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
-            input_feeding=self.input_feeding,
+            input_feeding=self.input_feeding, input_shapes=self.input_shapes,
         )
 
     def num_tokens(self, index):
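With input_shapes threaded from the dataset through collate() into collate_tokens, a batch of a given size is always padded to the same registered width, independent of its longest sentence. A standalone illustration of that padding contract; fixed_collate is a hypothetical stand-in, not the fairseq API:

# Standalone illustration of fixed-width collation; 'fixed_collate' is a
# hypothetical stand-in for collate_tokens with input_shapes, not fairseq API.
import torch

def fixed_collate(values, pad_idx, input_shapes):
    padlen = dict(input_shapes)[len(values)]         # pad length registered for this batch size
    res = values[0].new_full((len(values), padlen), pad_idx)
    for i, v in enumerate(values):
        res[i, :v.size(0)] = v                       # right padding, as with left_pad=False
    return res

batch = [torch.tensor([4, 5, 6]), torch.tensor([7, 8])]
print(fixed_collate(batch, pad_idx=1, input_shapes=[(2, 8), (4, 16)]).shape)
# -> torch.Size([2, 8]), regardless of the longest sentence in the batch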

fairseq/models/transformer.py

+4-4
@@ -355,8 +355,8 @@ def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=Fa
 
         # compute padding mask
         encoder_padding_mask = src_tokens.eq(self.padding_idx)
-        if not encoder_padding_mask.any():
-            encoder_padding_mask = None
+        # if not encoder_padding_mask.any():
+        #     encoder_padding_mask = None
 
         encoder_states = [] if return_all_hiddens else None
 
@@ -596,8 +596,8 @@ def extract_features(
         x = x.transpose(0, 1)
 
         self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
-        if not self_attn_padding_mask.any() and not self.cross_self_attention:
-            self_attn_padding_mask = None
+        # if not self_attn_padding_mask.any() and not self.cross_self_attention:
+        #     self_attn_padding_mask = None
 
         # decoder layers
         attn = None
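The commented-out branches used to replace an all-False padding mask with None. On XLA that is doubly costly: evaluating mask.any() in Python forces the value back to the host, and the mask-present and mask-absent paths become two differently traced graphs, so the mask is now kept as a tensor even when nothing is padded. A tiny plain-PyTorch illustration of the invariant being preserved:

# Plain-PyTorch sketch: keep the padding mask a tensor even when no position
# is padded, so downstream ops always see the same operand shapes and dtypes.
import torch

padding_idx = 1
src_tokens = torch.tensor([[5, 6, 7, 8]])             # a batch with no padding at all
encoder_padding_mask = src_tokens.eq(padding_idx)     # all-False bool mask, shape (1, 4)

# Removed behaviour: collapse the mask to None when it is all False.
# if not encoder_padding_mask.any():
#     encoder_padding_mask = None
# On XLA, .any() would force a device-to-host read, and the None/tensor split
# would produce two different compiled graphs.
print(encoder_padding_mask.dtype, encoder_padding_mask.shape)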

fairseq/modules/multihead_attention.py

+70-26
@@ -39,9 +39,17 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                              'value to be of the same size'
 
-        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
-        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        if self.qkv_same_dim:
+            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+        else:
+            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+        if bias:
+            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+        else:
+            self.register_parameter('in_proj_bias', None)
 
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
@@ -57,11 +65,12 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
 
         self.onnx_trace = False
 
+        # XXX: (taylanbil) try F.multi...
         self.enable_torch_version = False
-        if hasattr(F, "multi_head_attention_forward"):
-            self.enable_torch_version = True
-        else:
-            self.enable_torch_version = False
+        # if hasattr(F, "multi_head_attention_forward"):
+        #     self.enable_torch_version = True
+        # else:
+        #     self.enable_torch_version = False
 
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
@@ -70,15 +79,15 @@ def reset_parameters(self):
         if self.qkv_same_dim:
             # Empirically observed the convergence to be much better with
             # the scaled initialization
-            nn.init.xavier_uniform_(self.k_proj.weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.v_proj.weight, gain=1/math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_proj.weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.in_proj_weight, gain=1/math.sqrt(2))
         else:
-            nn.init.xavier_uniform_(self.k_proj.weight)
-            nn.init.xavier_uniform_(self.v_proj.weight)
-            nn.init.xavier_uniform_(self.q_proj.weight)
+            nn.init.xavier_uniform_(self.k_proj_weight)
+            nn.init.xavier_uniform_(self.v_proj_weight)
+            nn.init.xavier_uniform_(self.q_proj_weight)
 
         nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.in_proj_bias is not None:
+            nn.init.constant_(self.in_proj_bias, 0.)
         nn.init.constant_(self.out_proj.bias, 0.)
         if self.bias_k is not None:
             nn.init.xavier_normal_(self.bias_k)
@@ -146,23 +155,19 @@ def forward(
         saved_state = None
 
         if self.self_attention:
-            q = self.q_proj(query)
-            k = self.k_proj(query)
-            v = self.v_proj(query)
+            q, k, v = self.in_proj_qkv(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
-            q = self.q_proj(query)
+            q = self.in_proj_q(query)
             if key is None:
                 assert value is None
                 k = v = None
             else:
-                k = self.k_proj(key)
-                v = self.v_proj(key)
+                k = self.in_proj_k(key)
+                v = self.in_proj_v(key)
 
         else:
-            q = self.q_proj(query)
-            k = self.k_proj(key)
-            v = self.v_proj(value)
+            raise
         q *= self.scaling
 
         if self.bias_k is not None:
@@ -242,10 +247,9 @@ def forward(
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2),
-                float('-inf'),
-            )
+            attn_weights = attn_weights.transpose(0, 2)
+            attn_weights.masked_fill_(key_padding_mask, float('-inf'))
+            attn_weights = attn_weights.transpose(0, 2)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         if before_softmax:
@@ -330,3 +334,43 @@ def upgrade_state_dict_named(self, state_dict, name):
 
         for key, value in items_to_add.items():
             state_dict[key] = value
+
+    def in_proj_qkv(self, query):
+        return self._in_proj(query).chunk(3, dim=-1)
+
+    def in_proj_q(self, query):
+        if self.qkv_same_dim:
+            return self._in_proj(query, end=self.embed_dim)
+        else:
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[:self.embed_dim]
+            return F.linear(query, self.q_proj_weight, bias)
+
+    def in_proj_k(self, key):
+        if self.qkv_same_dim:
+            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        else:
+            weight = self.k_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[self.embed_dim:2 * self.embed_dim]
+            return F.linear(key, weight, bias)
+
+    def in_proj_v(self, value):
+        if self.qkv_same_dim:
+            return self._in_proj(value, start=2 * self.embed_dim)
+        else:
+            weight = self.v_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[2 * self.embed_dim:]
+            return F.linear(value, weight, bias)
+
+    def _in_proj(self, input, start=0, end=None):
+        weight = self.in_proj_weight
+        bias = self.in_proj_bias
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
+        return F.linear(input, weight, bias)
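The restored in_proj scheme packs the query, key, and value projections into a single (3 * embed_dim, embed_dim) weight that _in_proj slices, and the key padding mask is applied by transposing and filling in place rather than by unsqueeze-and-broadcast. A quick standalone check that the transposed masked_fill_ matches the old broadcasting form, using arbitrary sizes and plain PyTorch:

# Check that transpose + in-place masked_fill_ matches the old
# unsqueeze-and-broadcast masking of attention weights.
import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn = torch.randn(bsz, num_heads, tgt_len, src_len)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -1] = True                        # last source position is padding

old = attn.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

new = attn.clone().transpose(0, 2)                    # (tgt_len, num_heads, bsz, src_len)
new.masked_fill_(key_padding_mask, float('-inf'))     # mask broadcasts over the last two dims
new = new.transpose(0, 2)                             # back to (bsz, num_heads, tgt_len, src_len)

assert torch.equal(old, new)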

fairseq/tasks/fairseq_task.py

+11-4
@@ -146,10 +146,17 @@ def get_batch_iterator(
         )
 
         # create mini-batches with given size constraints
-        batch_sampler = data_utils.batch_by_size(
-            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
-            required_batch_size_multiple=required_batch_size_multiple,
-        )
+        if getattr(self.args, 'use_gpu', True):
+            batch_sampler = data_utils.batch_by_size(
+                indices, dataset.num_tokens, max_tokens=max_tokens,
+                max_sentences=max_sentences,
+                required_batch_size_multiple=required_batch_size_multiple,
+            )
+        else:
+            batch_sampler = data_utils.batch_by_size_tpu(
+                indices, dataset.num_tokens,
+                getattr(self.args, 'input_shapes', None)
+            )
 
         # return a reusable, sharded iterator
         epoch_iter = iterators.EpochBatchIterator(
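get_batch_iterator now branches on args.use_gpu: the GPU path keeps the dynamic token-budget batching, while the TPU path feeds batch_by_size_tpu the fixed input_shapes buckets. Since batch_by_size_tpu tries buckets in order and takes the first fit, a pre-flight check like the following may be useful; treating args.input_shapes as a list of (batch_size, pad_length) pairs is an assumption about how this branch's TPU entry point configures it:

# Sketch of a pre-flight check on args.input_shapes; treating it as a list of
# (batch_size, pad_length) pairs is an assumption, not upstream fairseq CLI.
input_shapes = [(256, 16), (128, 32), (64, 64)]

# Buckets should be sorted by pad length so batch_by_size_tpu tries the
# tightest fit first, and batch_size * pad_length kept roughly constant so
# every compiled graph does a similar amount of work per step.
assert all(p1 < p2 for (_, p1), (_, p2) in zip(input_shapes, input_shapes[1:]))
print([b * p for b, p in input_shapes])   # [4096, 4096, 4096]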
