Rebasing tpu branch on a more recent fairseq upstream commit #19

Merged: 214 commits, Nov 19, 2019

Commits (the file diff further below shows changes from 1 commit)
b002d00
v0.7.1 -> v0.7.2 (#891)
Jul 19, 2019
be5821b
Switch to torch.nn.functional.gelu when available
Jul 19, 2019
8af5554
Improve interactive generation (support --tokenizer and --bpe)
Jul 19, 2019
c811e0e
Store task in the criterion base class
Jul 19, 2019
ffe53d6
Create standalone label_smoothed_nll_loss
Jul 19, 2019
7efde22
Allow not specifying --warmup-init-lr
Jul 19, 2019
69d0f7f
Rename _load_model_ensemble -> load_model_ensemble_and_task
Jul 19, 2019
f812e52
Rename data.transforms -> data.encoders
Jul 21, 2019
1f96d28
Fix topp sampling issues (#882)
Jul 21, 2019
5f78106
Default to mmap and infer dataset implementations automatically
Jul 21, 2019
62b5498
Update GPT-2 BPE
Jul 21, 2019
9c89e88
Misc improvements to torch hub interface
Jul 22, 2019
47fd985
Move Masked LM components to legacy/ -- new ones are coming
Jul 22, 2019
bccfa7d
Add fallback for SLURM config
Jul 22, 2019
906411d
Fix --reset-meters
Jul 22, 2019
51ba352
Simplify hubconf
Jul 22, 2019
654affc
Add new Datasets
Jul 22, 2019
e8d609a
Add new Masked LM task + criterion
Jul 22, 2019
a03fe6f
Implement sparse transformer fixed attention pattern (#804)
Jul 22, 2019
30123e2
Fix read_binarized.py script
Jul 23, 2019
af6b361
Initializing mask as a tensor of ints (not long) (#875)
taylanbil Jul 23, 2019
208295d
Update README.md
Jul 23, 2019
b49ea81
check save_dir before beginning training
Jul 24, 2019
3d764a3
Update torch.hub usage
Jul 25, 2019
8835d93
Standardize on 'teacher forcing' rather than 'input feeding' which is…
Jul 25, 2019
17fcc72
Add RoBERTa README
Jul 27, 2019
40f1687
Add return_all_hiddens flag to hub interface
Jul 27, 2019
5218a7c
Fix compatibility with PyTorch 1.0.x (Fixes #906)
Jul 28, 2019
abc13e2
Make hub_utils.generator inherit from nn.Module
Jul 28, 2019
8207f26
Misc dataset improvements
Jul 28, 2019
1362b21
Correctly zero padding index in TransformerSentenceEncoder
Jul 28, 2019
c446c44
Add Adamax optimizer
Jul 28, 2019
76ff39f
Change default --num-workers to 1
Jul 28, 2019
a80cade
Update BPE library code
Jul 29, 2019
8d036c2
Add RoBERTa
Jul 29, 2019
ce7f044
Add instructions to load RoBERTa models on PyTorch 1.0
Jul 29, 2019
36df0da
Fix RoBERTa model import (fixes #918)
Jul 29, 2019
2f6d8b3
Add missing files for RoBERTa hub interface
Jul 29, 2019
2fe45f0
Update README.md to add top-p sampling (#783)
xingz9 Jul 29, 2019
33597e5
Support different --max-positions and --tokens-per-sample
Jul 29, 2019
138dc8e
adding glue data preprocessing scripts (#771)
Jul 29, 2019
c132b9b
Fix tokenization (fixes #926) (#929)
Jul 30, 2019
e75cff5
Relicense fairseq under MIT license (#786)
Jul 30, 2019
3b2cecd
1) replaced fstring 2) fixed error from max-positions arg
Jul 30, 2019
d82517e
Add roberta.decode to hub interface to decode BPE (#931)
Jul 30, 2019
b651b00
Wmt19 models (#767)
nng555 Jul 31, 2019
37eb9f2
Use commandline interface in preprocess_GLUE_tasks.sh (#937)
villmow Jul 31, 2019
c5650bf
Update language_model README.md (#941)
nadongguri Jul 31, 2019
fe8a163
Roberta add classification finetuning example readme (#790)
ngoyal2707 Jul 31, 2019
94722a9
Fix citation errors (#791)
nng555 Jul 31, 2019
3e0e5be
Fix small syntax error in hub_utils.py (fixes #942)
Aug 1, 2019
5b2be87
Update PyTorch Hub interface
Aug 1, 2019
4abadbd
Fix sampling with beam>1
Aug 1, 2019
430905d
Changed tensor comparison return type from uint8 to bool (#21113)
izdeby Aug 1, 2019
45f23f6
Add more details for bulk BPE encoding
Aug 1, 2019
ea6cc1d
Use ==/!= to compare str, bytes, and int literals (#948)
cclauss Aug 1, 2019
ccb5dea
Fix wmt19 links (#796)
nng555 Aug 1, 2019
5f34252
Update beam search code to support torch.bool change
Aug 2, 2019
abb7ed4
Update READMEs for torch.hub
Aug 2, 2019
f02f70c
Add single-models for WMT'19 for hub tutorial
Aug 2, 2019
3903f46
Fewer torch.hub requirements (#959)
Aug 2, 2019
9012e87
Avoid cast in PositionalEmbeddings to fix BLEU drop in pytorch native…
cndn Aug 2, 2019
12258e5
Fix generating with a fixed prefix
Aug 3, 2019
c728b86
remove default params from args so architecture works properly
alexeib Aug 3, 2019
1684e16
Add doc string for Roberta.encode function
Aug 4, 2019
5d543f9
fixed roberta finetuning with --find-unused-parameters on multiGPU
Aug 5, 2019
e40e4b2
Add back set_epoch functionality lost in RoBERTa merge
Aug 6, 2019
2b7843d
Add code to realign RoBERTa features to word-level tokenizers
Aug 7, 2019
1e55bbd
Fix tests and GLUE finetuning (fixes #989)
Aug 7, 2019
a9eda73
Added mask_fill api and some examples in README (#807)
Aug 7, 2019
9a1038f
fixed reloading from checkpoint (#811)
Aug 7, 2019
72f9364
Asr initial push (#810)
Aug 8, 2019
439ead5
Integrate with Apache Arrow/Plasma in-memory store for large datasets…
Aug 8, 2019
6398aa9
replace 'mkdir' with 'mkdir -p' (#997)
gmhafiz Aug 8, 2019
3563e59
added superglue dev set results to readme
Aug 9, 2019
838e108
MacOS requires c++ flag (#1000)
vincentqb Aug 9, 2019
b6c55b6
added sentence ranking task and loss (#809)
jingfeidu Aug 9, 2019
a00ce13
Fix Python 3.5 compat
Aug 10, 2019
8324919
Add WSC task and criterion
Aug 10, 2019
c0a5d29
Fix torch.hub for MNLI
Aug 10, 2019
3bbdc55
Update --restore-file logic (partially fixes #999)
Aug 12, 2019
969f447
Remove LAMB optimizer (at least until we can test it more)
Aug 12, 2019
2b68e91
Lint
Aug 12, 2019
d003664
Minor fixes for RACE finetuning (#818)
Aug 12, 2019
0563d87
ignore files starting with . e.g. .ipynb_checkpoints (#819)
uralik Aug 12, 2019
577e4fa
fix cosine scheduler docstring
Aug 13, 2019
a171c2d
added readme code for inference with GLUE finetuned model
Aug 13, 2019
a33ac06
Add Commonsense QA task
Aug 13, 2019
d015d23
Add fairseq-validate
Aug 13, 2019
baa8ce1
Updates for PyTorch 1.2 masking/bool behavior
Aug 14, 2019
7c89e13
Fix tests
Aug 14, 2019
ffffe04
v0.7.2 -> v0.8.0 (#1017)
Aug 14, 2019
b870468
Update READMEs
Aug 14, 2019
f840564
initial light and dynamic convolution kernels (#547)
nng555 Aug 14, 2019
1d44cc8
added effcient wsc task/criterion for winogrande (#825)
ngoyal2707 Aug 15, 2019
ac66df4
Update README
Aug 15, 2019
49177c9
Backward reranking public (#667)
nng555 Aug 15, 2019
a8e3211
Update README
Aug 15, 2019
ed27ed8
BMUF Resetting local state param
Aug 15, 2019
a3cfd51
added hf bert bpe
Aug 16, 2019
851c022
added check in token block dataset for multiple consecutive blank lines
Aug 16, 2019
732d15a
implement tri-stage lr_scheduler (#1028)
Aug 17, 2019
0c75c76
Fix bug (the returned value has a dimension mismatch) in label-smooth…
violet-zct Aug 19, 2019
02cb5a4
remove shlex.quote in scripts/spm_train.py (#972)
freewym Aug 19, 2019
79460d3
add constrains when checking multiple consecutive blank lines (#1031)
Trinkle23897 Aug 19, 2019
2eb53b8
Add instructions to resume training from released RoBERTa models (fix…
Aug 19, 2019
6ce55e4
Small fixes
Aug 19, 2019
c81fed4
Back out "[fairseq][PR] Fix bug (the returned value has a dimension m…
Aug 19, 2019
4812f64
Fix method has same name as property
Aug 20, 2019
9e5edc1
Give path when checkpoint can't be found (#1040)
aryamccarthy Aug 20, 2019
7a31fe0
vggblock support without pooling and pooling_kernel_size missing self…
siddalmia Aug 21, 2019
a2f5361
Multiset (#838)
alexeib Aug 21, 2019
ba5f829
Parameterized criterions (#808)
0xjc Aug 21, 2019
93057cc
fix string format to work in python 3.5 (#1050)
Trinkle23897 Aug 21, 2019
3c2cf3b
Misc changes
Aug 22, 2019
8c509a9
Add links to cuda models (#828)
nng555 Aug 22, 2019
d4c9136
Fix year in noisy channel citation (#842)
nng555 Aug 22, 2019
6e2bd79
wav2vec everstore support
Aug 23, 2019
4fc3953
Cythonize token block dataset (#834)
Aug 23, 2019
833f053
Suppress leaked semaphore warnings
Aug 23, 2019
8a8c069
fix cython dependency in the setup (#847)
Aug 26, 2019
3ab8e0f
wav2vec everstore support fix
Aug 27, 2019
396ff7f
installing numpy headers for cython
Aug 27, 2019
920b85d
Minor update of README.md of language model example (#1063)
soskek Aug 27, 2019
d2410c4
Minor cleanup for setup.py
Aug 27, 2019
108f94b
use numpy function for filter by size when possible (#845)
Aug 28, 2019
0a96d22
Fix multi-gpu training (fixes #1088)
Aug 29, 2019
8777465
Adopt Contributor Covenant
zpao Aug 30, 2019
4a7cd58
set numpy seed explicitly + other minor fixes (#850)
alexeib Aug 30, 2019
c1951aa
add missing colorize dataset
alexeib Aug 31, 2019
746e59a
Improve support for `python setup.py build_ext --inplace`
Aug 31, 2019
8d4588b
Cleaner handling of numpy-based extensions in setup.py
Aug 31, 2019
20dfba7
fixed numpy based size filtering (#854)
Sep 1, 2019
6c00b33
Fix an error in the command about Hierarchical Neural Story Generatio…
altale Sep 3, 2019
1f0f7cd
added cython to install_requires
Sep 3, 2019
1566cfb
Fix multilingual translation bug for to-many case
pipibjc Sep 4, 2019
3e3fe72
Return predicted token for RoBERTa filling mask
raedle Sep 5, 2019
1fd8943
Average local optimizer param after warmup and during bmuf sync
Sep 12, 2019
e1ba32a
added fast stats sync option (#858)
Sep 16, 2019
a3882ab
Update README.md
Sep 17, 2019
31dd13f
Fix link to RACE fine-tuning instructions.
nelson-liu Sep 17, 2019
718677e
dont project maske tokens for mlm loss (#859)
Sep 18, 2019
8dbee4a
Minor fix to make adafactor work for >2d conv kernels (#1122)
akhileshgotmare Sep 18, 2019
f994c9b
Add autogenerated cython files to gitignore (#860)
jma127 Sep 18, 2019
0eaaf35
Add cython language_level hints
Sep 19, 2019
a8a85c2
Add dataset class for weighted sampling with replacement. (#861)
jma127 Sep 19, 2019
3233540
added multilingual masked LM training (#849)
Sep 20, 2019
e869c80
Update README.race.md
Sep 20, 2019
10f9349
Remove extraneous call to RNG in multi-GPU code path
Sep 20, 2019
3b09b98
fixed train valid epoch iter
Sep 23, 2019
3f4fc50
Miscellaneous documentation improvements: (#868)
jma127 Sep 23, 2019
2ed65b6
fixed corner case in mlm criterion when all tokens get masked
Sep 23, 2019
fa7dea6
Issue 1146: Minor fix to roberta pre-training readme (#1165)
mortonjt Sep 24, 2019
e073ddf
PR for Issue #1154: Two comments in lstm.py seem to be incorrect
vineetk1 Sep 26, 2019
2314979
Update getting_started.rst (#1188)
Michaelvll Sep 27, 2019
62e65c4
Explain the language modelling format in RoBERTa pretraining readme
louismartin Sep 27, 2019
6c1da0f
Fixing BMUF warmup and sync strategy
Sep 27, 2019
86857a5
Levenshtein Transformer paper code
kahne Sep 27, 2019
1cb267e
Fixing example of batched predictions for Roberta (#1195)
justachetan Sep 27, 2019
ea1a410
RoBERTa now supported on TPU and TensorFlow via transformers library
Sep 28, 2019
4ac2c5f
Implementation of the WeCNLP abstract "Cross+Self-Attention for Trans…
stephanpeitz Sep 29, 2019
1351972
fix typo in README of examples/translation
Sep 29, 2019
acb6fba
Fix torch.hub to not depend on libnat
Sep 30, 2019
1c66792
Implementation of the paper "Jointly Learning to Align and Translate …
sarthakgarg Sep 30, 2019
58e43cb
extract FP16OptimizerMixin for share the same logic in PyText (#1180)
chenyangyu1988 Oct 1, 2019
de348d1
Native Torchscript Wordpiece Tokenizer Op for BERTSquadQA, Torchscrip…
Oct 4, 2019
315c463
Add periodic CUDA cache cleanup (#882)
jma127 Oct 4, 2019
4cb895b
add pre-trained wav2vec model
alexeib Oct 5, 2019
6f58e15
Setting Global sync to 50 in BMUF
Oct 7, 2019
c216522
fix max lengths in Levenshtein Tramsformer
kahne Oct 8, 2019
34e79c5
ensemble levts
Oct 8, 2019
63b6b3f
Add printing of PyTorch memory summary on OOM (#885)
jma127 Oct 8, 2019
b6e001f
Fix data loading memory issue in pyspeech
Oct 9, 2019
33646ac
wav2letter integration
0xjc Oct 10, 2019
c4893ca
Add ctc loss to ASR task (#1233)
Oct 10, 2019
cce92bd
add new_arange function + FIX BUGS of returning attn values
MultiPath Oct 11, 2019
02b74c5
fix the random mask function for CMLM model
MultiPath Oct 11, 2019
d80ad54
Added option to save checkpoints using Path Manager.
sujitoc Oct 12, 2019
e3a40d9
fix libnat imports
kahne Oct 15, 2019
b5f41f8
Add Unit test cases for BMUF
Oct 15, 2019
3dcb5c7
fix levenshtein transfromer attn
kahne Oct 18, 2019
c8a7b62
fixed a bug in preprocess glue dataset dev filename (#1270)
DikshaMeghwal Oct 18, 2019
b8d024e
add missing function to FairseqLanguageModel
Oct 18, 2019
a3c629b
Fix typos on Examples for Nonautoregressive translation
MultiPath Oct 20, 2019
66d24dc
Enable separate models for insertion and deletion;
MultiPath Oct 20, 2019
34e6a5e
Fix load_dataset signature (#1281)
louismartin Oct 22, 2019
2d51e04
Rename "loaded {} batches" to "loaded {} blocks" (#1279)
louismartin Oct 22, 2019
e49b302
fix score
kahne Oct 22, 2019
8defa9d
Add warmup support in reduce_on_plateau lr schedule
Oct 23, 2019
5a2f76e
NAT productionization
cndn Oct 24, 2019
39faa0a
Reset both WPS and UPS on first minibatch (#891)
jma127 Oct 24, 2019
d0358bb
fix inconsistency w/ recent pytorch cuda device logic
jma127 Oct 24, 2019
5b086a0
OSS tracing compliant transformer to unbreak master (#1299)
cndn Oct 24, 2019
fdf4c3e
Simplify fairseq multihead attention (#888)
halilakin Oct 25, 2019
c07362c
Convert matmuls to quantizable nn.Linear modules (#1304)
halilakin Oct 25, 2019
eb68afc
fix a type mismatch in NAT quantization run
xianxl Oct 26, 2019
dabbef4
adding layerdrop code for training, pruning, and readme (#890)
huihuifan Oct 27, 2019
50cf3bb
Fix LevT generator interface
cndn Oct 28, 2019
856d8b8
layer drop
xianxl Oct 30, 2019
f30fc7d
Fix MultiheadAttention and torch hub
Oct 31, 2019
99c524c
Fix fairspeq unit test
Oct 31, 2019
4c6b689
Remove in_proj_weight/in_proj_bias in multihead attention and fix the…
halilakin Nov 1, 2019
828c1ca
Fix BPE for dual learning
chtran Nov 1, 2019
a0f7599
Fix building of docs
Nov 2, 2019
fd7dcac
option to suppress loss report
taylanbil Nov 8, 2019
7a23b93
Making tpu training work
taylanbil Jun 13, 2019
f17ad03
send meters to device
taylanbil Nov 14, 2019
734b14f
Revert inplace masked_fill_s so convergence occurs
taylanbil Nov 16, 2019
d370e6b
Merge branch 'tpu-rebase-master' of github.com:taylanbil/fairseq into…
taylanbil Nov 16, 2019
043b6a9
git wtf
taylanbil Nov 16, 2019
12aaf54
Clean up comments, unused imports, and reuse var in checkpoint saving
taylanbil Nov 18, 2019
8de1826
Added comments to various places of tpu related code change, and fixe…
taylanbil Nov 18, 2019
5120a2b
Added comments to various places of tpu related code change, and fixe…
taylanbil Nov 18, 2019
bbfeec9
More documentation for sequence padding
taylanbil Nov 18, 2019
24 changes: 16 additions & 8 deletions fairseq/checkpoint_utils.py
@@ -17,6 +17,8 @@

from fairseq.models import FairseqEncoder, FairseqDecoder

import torch_xla.core.xla_model as xm


def save_checkpoint(args, trainer, epoch_itr, val_loss):
from fairseq import distributed_utils, meters
@@ -62,15 +64,17 @@ def is_better(a, b):
extra_state.update({'best': save_checkpoint.best})

checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]

if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
try:
from fairseq.fb_pathmgr import fb_pathmgr
fb_pathmgr.copy(checkpoints[0], cp, True)
if getattr(args, 'use_gpu', True) or xm.is_master_ordinal():

Review comment: Maybe move this out of the try into a local and reuse.

Author reply: will do

fb_pathmgr.copy(checkpoints[0], cp, True)
except (ModuleNotFoundError, ImportError):
shutil.copyfile(checkpoints[0], cp)

if getattr(args, 'use_gpu', True) or xm.is_master_ordinal():
shutil.copyfile(checkpoints[0], cp)
write_timer.stop()
print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
checkpoints[0], epoch, updates, write_timer.sum))
@@ -97,7 +101,7 @@ def is_better(a, b):
def load_checkpoint(args, trainer, data_selector=None):
"""Load a checkpoint and restore the training iterator."""
# only one worker should attempt to create the required dir
if args.distributed_rank == 0:
if args.distributed_rank == 0 or xm.is_master_ordinal():
os.makedirs(args.save_dir, exist_ok=True)

if args.restore_file == 'checkpoint_last.pt':
@@ -210,7 +214,8 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
def torch_persistent_save(*args, **kwargs):
for i in range(3):
try:
return torch.save(*args, **kwargs)
save_func = xm.save if kwargs.pop('xla', False) else torch.save
return save_func(*args, **kwargs)
except Exception:
if i == 2:
logging.error(traceback.format_exc())
@@ -256,14 +261,17 @@ def save_state(
state_dict['criterion'] = criterion.state_dict()
if not args.no_save_optimizer_state:
state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())

try:
from fairseq.fb_pathmgr import fb_pathmgr
with fb_pathmgr.open(filename, "wb") as f:
torch_persistent_save(state_dict, f)
torch_persistent_save(
state_dict, f, xla=not getattr(args, 'use_gpu', True)
)
except (ModuleNotFoundError, ImportError):
# if path manager not found, continue with local file.
torch_persistent_save(state_dict, filename)
torch_persistent_save(
state_dict, filename, xla=not getattr(args, 'use_gpu', True)
)


def _upgrade_state_dict(state):
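Illustrative sketch (not part of the PR): the checkpoint changes above route saves through xm.save when running on TPU, so tensors are moved off the XLA device and, in multi-core runs, typically only the master ordinal writes the file. A minimal stand-alone version of that dispatch might look like the following; the helper name, the use_xla flag, and the retry behavior are assumptions, not the PR's exact code.

import logging
import traceback

import torch

try:
    import torch_xla.core.xla_model as xm  # only present in TPU environments
except ImportError:
    xm = None


def persistent_save(obj, path, use_xla=False, retries=3):
    # Pick the saving backend: xm.save handles XLA tensors and multi-core
    # coordination, torch.save is the plain CPU/GPU path.
    save_func = xm.save if (use_xla and xm is not None) else torch.save
    for i in range(retries):
        try:
            return save_func(obj, path)
        except Exception:
            # Retry transient I/O failures; log and re-raise on the last attempt.
            if i == retries - 1:
                logging.error(traceback.format_exc())
                raise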
8 changes: 4 additions & 4 deletions fairseq/criterions/label_smoothed_cross_entropy.py
@@ -17,8 +17,8 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
non_pad_mask = target.ne(ignore_index)
nll_loss = nll_loss[non_pad_mask]
smooth_loss = smooth_loss[non_pad_mask]
nll_loss.masked_fill_(~non_pad_mask, 0.0)
smooth_loss.masked_fill_(~non_pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
@@ -57,8 +57,8 @@ def forward(self, model, sample, reduce=True):
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'loss': loss.data,
'nll_loss': nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
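Note on the loss-masking change above (illustration only, not part of the PR): indexing with a boolean mask produces a tensor whose size depends on how much padding is in the batch, which makes XLA retrace and recompile for each new shape; masked_fill_ keeps the shape fixed and merely zeroes the padded positions, so the reduced loss value is unchanged. A tiny self-contained comparison, with made-up values and an assumed padding index of 1:

import torch

nll_loss = torch.tensor([[0.5], [1.2], [0.3], [0.9]])
target = torch.tensor([[7], [1], [4], [1]])   # 1 is the (assumed) padding index
non_pad_mask = target.ne(1)

dynamic = nll_loss[non_pad_mask]                             # shape depends on the data
static = nll_loss.clone().masked_fill_(~non_pad_mask, 0.0)   # shape is always [4, 1]

print(dynamic.shape, static.shape)    # torch.Size([2]) torch.Size([4, 1])
print(dynamic.sum(), static.sum())    # same value either way: tensor(0.8000)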
37 changes: 34 additions & 3 deletions fairseq/data/data_utils.py
@@ -26,9 +26,25 @@ def infer_language_pair(path):
return src, dst


def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
def get_pad_size(values, input_shapes):
if input_shapes is None:
return max(v.size(0) for v in values)
for batch_size, padlen in input_shapes:
if len(values) == batch_size:
return padlen
else:
raise IndexError(
'Encountered values with invalid length {}, input shapes were {}'
.format(len(values), input_shapes)
)


def collate_tokens(
values, pad_idx, eos_idx=None, left_pad=False,
move_eos_to_beginning=False, input_shapes=None,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = get_pad_size(values, input_shapes)
res = values[0].new(len(values), size).fill_(pad_idx)

def copy_tensor(src, dst):
@@ -227,10 +243,25 @@ def batch_by_size(

if isinstance(indices, types.GeneratorType):
indices = np.fromiter(indices, dtype=np.int64, count=-1)

return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)


def batch_by_size_tpu(
indices, num_tokens_fn, input_shapes
):
batches = [[] for _ in input_shapes]
for idx in indices:
sample_len = num_tokens_fn(idx)
for j, (batch_size, padlen) in enumerate(input_shapes):

Review comment: Is this assuming that the input_shapes list will be sorted by shortest to longest padlen?

Author reply: yes. (An illustrative sketch of this bucketing follows this file's diff.)

if padlen < sample_len:
continue
batches[j].append(idx)
if len(batches[j]) == batch_size:
yield batches[j]
batches[j] = []
break


def process_bpe_symbol(sentence: str, bpe_symbol: str):
if bpe_symbol == 'sentencepiece':
sentence = sentence.replace(' ', '').replace('\u2581', ' ').strip()
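Illustrative sketch (not the PR's code): batch_by_size_tpu above buckets samples into a fixed set of (batch_size, pad_length) shapes so that every batch the TPU sees has one of a handful of static shapes. As the review thread notes, input_shapes must be sorted from shortest to longest pad length, so each sample lands in the smallest bucket that fits it. The lengths and shapes below are made up.

def batch_by_fixed_shapes(indices, num_tokens_fn, input_shapes):
    # input_shapes: list of (batch_size, pad_length), sorted by increasing pad_length.
    buckets = [[] for _ in input_shapes]
    for idx in indices:
        length = num_tokens_fn(idx)
        for j, (batch_size, pad_len) in enumerate(input_shapes):
            if length <= pad_len:            # first bucket that can hold the sample
                buckets[j].append(idx)
                if len(buckets[j]) == batch_size:
                    yield buckets[j]         # emit a full, fixed-shape batch
                    buckets[j] = []
                break


lengths = [12, 45, 7, 30, 64, 9, 50, 20]     # hypothetical sample lengths
shapes = [(3, 32), (2, 64)]                  # shortest pad length first
for batch in batch_by_fixed_shapes(range(len(lengths)), lambda i: lengths[i], shapes):
    print(batch)
# [0, 2, 3]  -> padded to length 32
# [1, 4]     -> padded to length 64
# (as in the original, trailing partially-filled buckets are dropped)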
12 changes: 7 additions & 5 deletions fairseq/data/language_pair_dataset.py
@@ -11,15 +11,15 @@

def collate(
samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
input_feeding=True,
input_feeding=True, input_shapes=None,
):
if len(samples) == 0:
return {}

def merge(key, left_pad, move_eos_to_beginning=False):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
[s[key] for s in samples], pad_idx,
eos_idx, left_pad, move_eos_to_beginning, input_shapes,
)

def check_alignment(alignment, src_len, tgt_len):
@@ -154,7 +154,8 @@ def __init__(
shuffle=True, input_feeding=True,
remove_eos_from_source=False, append_eos_to_target=False,
align_dataset=None,
append_bos=False
append_bos=False,
input_shapes=None,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
@@ -178,6 +179,7 @@
if self.align_dataset is not None:
assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
self.append_bos = append_bos
self.input_shapes = input_shapes

Review comment: Optional: maybe we could add a docstring for this guy and clarify how it should be sorted?

Author reply: see https://github.com/pytorch-tpu/fairseq/blob/tpu/train.py#L291-L298; we error while parsing the input args if the shapes passed in don't satisfy the assumption, and describe the requirements there. (An assumed illustration of such a check follows this file's diff.)


def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
@@ -249,7 +251,7 @@ def collater(self, samples):
return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
input_feeding=self.input_feeding,
input_feeding=self.input_feeding, input_shapes=self.input_shapes,
)

def num_tokens(self, index):
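The author reply above points to an argument-parsing check in train.py rather than a docstring; the snippet below is only an assumed illustration of what such a check could look like (the helper name and message are hypothetical). It enforces the ordering the TPU bucketing relies on: (batch_size, pad_length) pairs sorted by increasing pad length.

def validate_input_shapes(input_shapes):
    # Hypothetical check, not the PR's actual code in train.py.
    pad_lens = [pad_len for _, pad_len in input_shapes]
    if any(a >= b for a, b in zip(pad_lens, pad_lens[1:])):
        raise ValueError(
            'input shapes must be sorted by strictly increasing pad length, '
            'got {}'.format(input_shapes)
        )
    return input_shapes


validate_input_shapes([(16, 64), (8, 128), (4, 256)])   # ok
# validate_input_shapes([(8, 128), (16, 64)])           # would raise ValueError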
8 changes: 4 additions & 4 deletions fairseq/models/transformer.py
@@ -355,8 +355,8 @@ def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=Fa

# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if not encoder_padding_mask.any():
encoder_padding_mask = None
#if not encoder_padding_mask.any():
# encoder_padding_mask = None

encoder_states = [] if return_all_hiddens else None

@@ -596,8 +596,8 @@ def extract_features(
x = x.transpose(0, 1)

self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None
# if not self_attn_padding_mask.any() and not self.cross_self_attention:
# self_attn_padding_mask = None

# decoder layers
attn = None
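Note on the commented-out short-circuit above (illustration only): `if not encoder_padding_mask.any(): encoder_padding_mask = None` makes Python control flow depend on tensor data, which forces a device-to-host sync and gives XLA a different graph depending on whether the batch contains padding. Keeping the mask is functionally equivalent, since masking with an all-False mask leaves the scores untouched; the values below are made up.

import torch

padding_idx = 1
src_tokens = torch.tensor([[5, 6, 7, 2, 3],
                           [8, 9, 2, 3, 4]])       # a batch with no padding
encoder_padding_mask = src_tokens.eq(padding_idx)  # all False

scores = torch.randn(2, 4, 5)                      # e.g. attention scores
masked = scores.masked_fill(encoder_padding_mask.unsqueeze(1), float('-inf'))
print(torch.equal(scores, masked))                 # True: all-False mask is a no-op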
96 changes: 70 additions & 26 deletions fairseq/modules/multihead_attention.py
@@ -39,9 +39,17 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
'value to be of the same size'

self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
if self.qkv_same_dim:
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
else:
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

if bias:
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)

self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

@@ -57,11 +65,12 @@ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=

self.onnx_trace = False

# XXX: (taylanbil) try F.multi...
self.enable_torch_version = False
if hasattr(F, "multi_head_attention_forward"):
self.enable_torch_version = True
else:
self.enable_torch_version = False
# if hasattr(F, "multi_head_attention_forward"):
# self.enable_torch_version = True
# else:
# self.enable_torch_version = False

def prepare_for_onnx_export_(self):
self.onnx_trace = True
@@ -70,15 +79,15 @@ def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1/math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1/math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1/math.sqrt(2))
nn.init.xavier_uniform_(self.in_proj_weight, gain=1/math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.k_proj_weight)
nn.init.xavier_uniform_(self.v_proj_weight)
nn.init.xavier_uniform_(self.q_proj_weight)

nn.init.xavier_uniform_(self.out_proj.weight)
if self.in_proj_bias is not None:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj.bias, 0.)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
Expand Down Expand Up @@ -146,23 +155,19 @@ def forward(
saved_state = None

if self.self_attention:
q = self.q_proj(query)

Review comment: For my understanding, was this for performance improvements? If so, did it help?

Author reply: yes, this was causing a 10% regression. (A sketch of the fused projection follows at the end of this file's diff.)

k = self.k_proj(query)
v = self.v_proj(query)
q, k, v = self.in_proj_qkv(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
q = self.in_proj_q(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
k = self.in_proj_k(key)
v = self.in_proj_v(key)

else:
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
raise
q *= self.scaling

if self.bias_k is not None:
@@ -242,10 +247,9 @@ def forward(
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_weights = attn_weights.transpose(0, 2)
attn_weights.masked_fill_(key_padding_mask, float('-inf'))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

if before_softmax:
@@ -330,3 +334,43 @@ def upgrade_state_dict_named(self, state_dict, name):

for key, value in items_to_add.items():
state_dict[key] = value

def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)

def in_proj_q(self, query):
if self.qkv_same_dim:
return self._in_proj(query, end=self.embed_dim)
else:
bias = self.in_proj_bias
if bias is not None:
bias = bias[:self.embed_dim]
return F.linear(query, self.q_proj_weight, bias)

def in_proj_k(self, key):
if self.qkv_same_dim:
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
else:
weight = self.k_proj_weight
bias = self.in_proj_bias
if bias is not None:
bias = bias[self.embed_dim:2 * self.embed_dim]
return F.linear(key, weight, bias)

def in_proj_v(self, value):
if self.qkv_same_dim:
return self._in_proj(value, start=2 * self.embed_dim)
else:
weight = self.v_proj_weight
bias = self.in_proj_bias
if bias is not None:
bias = bias[2 * self.embed_dim:]
return F.linear(value, weight, bias)

def _in_proj(self, input, start=0, end=None):
weight = self.in_proj_weight
bias = self.in_proj_bias
weight = weight[start:end, :]
if bias is not None:
bias = bias[start:end]
return F.linear(input, weight, bias)
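Sketch of the motivation for restoring the fused in_proj_weight (assumed shapes, not the module's exact code): for self-attention, one [3*embed_dim, embed_dim] matmul followed by chunk(3) produces the same q/k/v as three separate projections, but launches a single larger matmul, which the review thread above reports as the faster path on TPU.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, seq_len, bsz = 16, 10, 4
query = torch.randn(seq_len, bsz, embed_dim)

# Three separate projection weights (the q_proj/k_proj/v_proj style).
q_w, k_w, v_w = (torch.randn(embed_dim, embed_dim) for _ in range(3))
q1, k1, v1 = F.linear(query, q_w), F.linear(query, k_w), F.linear(query, v_w)

# Fused projection: concatenate the weights, do one matmul, then split.
in_proj_weight = torch.cat([q_w, k_w, v_w], dim=0)
q2, k2, v2 = F.linear(query, in_proj_weight).chunk(3, dim=-1)

print(torch.allclose(q1, q2, atol=1e-6),
      torch.allclose(k1, k2, atol=1e-6),
      torch.allclose(v1, v2, atol=1e-6))   # True True True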
15 changes: 11 additions & 4 deletions fairseq/tasks/fairseq_task.py
@@ -146,10 +146,17 @@ def get_batch_iterator(
)

# create mini-batches with given size constraints
batch_sampler = data_utils.batch_by_size(
indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
if getattr(self.args, 'use_gpu', True):
batch_sampler = data_utils.batch_by_size(
indices, dataset.num_tokens, max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
else:
batch_sampler = data_utils.batch_by_size_tpu(
indices, dataset.num_tokens,
getattr(self.args, 'input_shapes', None)
)

# return a reusable, sharded iterator
epoch_iter = iterators.EpochBatchIterator(