Rebasing tpu branch on a more recent fairseq upstream commit #19
Conversation
Summary: No major API changes since the last release. Cutting a new release since we'll be merging significant (possibly breaking) changes to logging, data loading and the masked LM implementation soon. Pull Request resolved: facebookresearch#891 Differential Revision: D16377132 Pulled By: myleott fbshipit-source-id: f1cb88e671ccd510e53334d0f449fe18585268c7
Summary: Pull Request resolved: fairinternal/fairseq-py#735 Differential Revision: D16377046 Pulled By: myleott fbshipit-source-id: 9725d4a3ce6b2fc8cee0b1d1cb8921f9d59c551a
Summary: Pull Request resolved: fairinternal/fairseq-py#734 Differential Revision: D16377044 Pulled By: myleott fbshipit-source-id: 37d5553d76aa7c653113fec089f59710281c31d7
Summary: Pull Request resolved: fairinternal/fairseq-py#737 Differential Revision: D16377805 Pulled By: myleott fbshipit-source-id: 1e090a02ff4fbba8695173f57d3cc5b88ae98bbf
Summary: Pull Request resolved: fairinternal/fairseq-py#739 Differential Revision: D16377798 Pulled By: myleott fbshipit-source-id: 20047c80de2e6f108269ace4ae3eec906a5920dd
Summary: Pull Request resolved: fairinternal/fairseq-py#736 Differential Revision: D16378001 Pulled By: myleott fbshipit-source-id: 2907f63bcbf7068ceaa48b00096040fa2639e569
Summary: Pull Request resolved: fairinternal/fairseq-py#738 Differential Revision: D16377803 Pulled By: myleott fbshipit-source-id: 6beb2f78e7464b70ff65a965d2b747cdca0ca951
Summary: Pull Request resolved: fairinternal/fairseq-py#747 Differential Revision: D16403464 Pulled By: myleott fbshipit-source-id: ee3b4184f129a02be833c7bdc00685978b4de883
Summary: Two issues here: 1. `last_included` should be the last included index `cumsum_mask[:, :, -1:]` instead of `cumsum_mask[:, :, :1]` (which is either 0 or 1); 2. If `--no-repeat-ngram-size` is set, the sum of `probs` may be less than 1, and we need to re-normalize to make it a valid probability distribution. The following code reproduces the issue:

```python
import torch
import numpy as np


def _sample_topp(probs):
    # ===== Code from fairseq/search.py _sample_topp ======
    # sort the last dimension (vocab dimension) in descending order
    sorted_probs, sorted_indices = probs.sort(descending=True)

    # compute a mask to indicate the words to be included in the top-P set.
    cumsum_probs = sorted_probs.cumsum(dim=2)
    mask = cumsum_probs.lt(sampling_topp)

    # note that mask was computed by 'lt'. One more word needs to be included
    # so that the cumulative probability mass can exceed p.
    cumsum_mask = mask.cumsum(dim=2)
    last_included = cumsum_mask[:, :, :1]
    mask = mask.scatter_(2, last_included, 1)

    # truncate unnecessary dims.
    max_dim = last_included.max()
    truncated_mask = mask[:, :, :max_dim + 1]
    truncated_probs = sorted_probs[:, :, :max_dim + 1]
    truncated_indices = sorted_indices[:, :, :max_dim + 1]

    # trim the words that are not in top-P by setting their probabilities
    # to 0, so that they would not be sampled later.
    trim_mask = 1 - truncated_mask
    trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
    return trimed_probs, truncated_indices
    # ========================================================


if __name__ == '__main__':
    np.random.seed(1234)
    torch.manual_seed(1234)
    sampling_topp = 0.9
    probs = torch.softmax(torch.randn(1, 1, 10), dim=-1)
    # probs = tensor([0.0545, 0.0779, 0.0189, 0.0647, 0.0282, 0.0862, 0.0656, 0.1041, 0.0399, 0.4600])
    print('probs =', probs[0][0])
    trimed_probs, truncated_indices = _sample_topp(probs)
    cum_probs = trimed_probs.cumsum(dim=-1)[0][0]
    # cumsum = tensor([0.4600, 0.5641])
    print('cumsum =', cum_probs)
    # Will throw AssertionError
    assert float(cum_probs[-1]) >= sampling_topp
```

Pull Request resolved: facebookresearch#882 Differential Revision: D16409269 Pulled By: xingz9 fbshipit-source-id: 94b1122eed50c656057b64e22af6f4a6ea7a68af
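For reference, a minimal sketch of what the two fixes amount to (hypothetical, not necessarily the exact merged patch); it mirrors the `_sample_topp` snippet above with the two problematic spots changed:

```python
import torch


def sample_topp_fixed(probs, sampling_topp=0.9):
    # fix 2: with --no-repeat-ngram-size some probability mass may already have
    # been zeroed out, so re-normalize first to get a valid distribution
    probs = probs / probs.sum(dim=-1, keepdim=True)

    sorted_probs, sorted_indices = probs.sort(descending=True)
    cumsum_probs = sorted_probs.cumsum(dim=2)
    mask = cumsum_probs.lt(sampling_topp)
    cumsum_mask = mask.cumsum(dim=2)

    # fix 1: the last included index is the count of entries below the threshold
    # (taken from the end of the cumulative sum), not its first entry
    last_included = cumsum_mask[:, :, -1:]
    mask = mask.scatter_(2, last_included, 1)

    max_dim = last_included.max()
    truncated_mask = mask[:, :, :max_dim + 1]
    truncated_probs = sorted_probs[:, :, :max_dim + 1]
    truncated_indices = sorted_indices[:, :, :max_dim + 1]
    trimmed_probs = truncated_probs.masked_fill_(~truncated_mask, 0)
    return trimmed_probs, truncated_indices
```

With the example inputs above, the trimmed probabilities then cover at least `sampling_topp` of the mass, so the final assertion should pass.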
Summary: Pull Request resolved: fairinternal/fairseq-py#751 Differential Revision: D16410989 Pulled By: myleott fbshipit-source-id: ddbbee49756f9ff6c4487977a3f5d2259b7abafe
Summary: Pull Request resolved: fairinternal/fairseq-py#749 Differential Revision: D16410984 Pulled By: myleott fbshipit-source-id: 7698df46b8a179afccb287990f9705358690454a
Summary: Pull Request resolved: fairinternal/fairseq-py#750 Differential Revision: D16410986 Pulled By: myleott fbshipit-source-id: 8ee6b4371d6ae5b041b00a54a6039a422345795e
Summary: Pull Request resolved: fairinternal/fairseq-py#740 Differential Revision: D16377797 Pulled By: myleott fbshipit-source-id: f7d6c8b00a77e279ea94376b1f0fcd15087eaf5f
Summary: Pull Request resolved: fairinternal/fairseq-py#752 Differential Revision: D16417582 Pulled By: myleott fbshipit-source-id: 6b4289febcf9290452bb91f1f2181a02c09c82a7
Summary: Pull Request resolved: fairinternal/fairseq-py#756 Differential Revision: D16418302 Pulled By: myleott fbshipit-source-id: 62495a0bff41d1741e2b09807a3b43ff2c66c8fb
Summary: Pull Request resolved: fairinternal/fairseq-py#758 Differential Revision: D16418932 Pulled By: myleott fbshipit-source-id: 59f005164b61b9fa712922eeb23525f7eec38f38
Summary: Pull Request resolved: fairinternal/fairseq-py#757 Differential Revision: D16418305 Pulled By: myleott fbshipit-source-id: 25f293a2792509f7a75c688e4bf8cff02e6bba2e
Summary: Pull Request resolved: fairinternal/fairseq-py#761 Differential Revision: D16421335 Pulled By: myleott fbshipit-source-id: 257d92c2b90361147642e2baa38486b4d18f6297
…h#804) Summary: Pull Request resolved: facebookresearch/pytext#804 Pull Request resolved: fairinternal/fairseq-py#746 Pull Request resolved: facebookresearch#894 Adding an implementation of the sparse transformer to multi-head attention, using the fixed attention pattern specified in https://arxiv.org/pdf/1904.10509.pdf. The sparse_mask masks out words using -inf; after softmax, -inf becomes 0. Thus, a mask does not need to be re-calculated and re-applied when multiplying attn_weights and values. Four inputs are added to the config: sparse, is_bidirectional, stride, expressivity. If we are using the sparse transformer, is_bidirectional, stride, and expressivity must be specified (there are defaults). If is_bidirectional is False, the mask uses the fixed attention pattern described in the paper. If is_bidirectional is True, subset one includes all values in the current stride window and a summary from every stride window; all other values are masked. Stride (L in the paper) controls the window size and expressivity (c in the paper) controls the size of the summary. Reviewed By: borguz Differential Revision: D16042988 fbshipit-source-id: c59166dc7cfe89187a256e4076000c2458842fd5
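As a rough illustration of the masking mechanism described in that summary (hypothetical helper, not the actual pytext/fairseq code): positions disallowed by the fixed pattern are set to -inf in the attention weights, so they come out of the softmax as exactly 0 and the mask never has to be re-applied to the attention probabilities.

```python
import torch


def masked_attn_softmax(attn_weights, sparse_mask):
    # sparse_mask is True wherever the fixed attention pattern forbids attending
    attn_weights = attn_weights.masked_fill(sparse_mask, float('-inf'))
    # -inf logits become probability 0 here, so no re-masking is needed afterwards
    return torch.softmax(attn_weights, dim=-1)
```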
Summary: Pull Request resolved: fairinternal/fairseq-py#762 Differential Revision: D16427266 Pulled By: myleott fbshipit-source-id: 9bd9b8c6b4994ae98a62a37b34d03265bd365453
Summary: Since mask really is a tensor of ints, this change should be mathematically equivalent to the base. On the other hand, this has performance implications for xla, hence the pull request. Pull Request resolved: facebookresearch#875 Differential Revision: D16232877 Pulled By: myleott fbshipit-source-id: e63175ee0016dcf0dfe10e2fd22570b8bbfbde84
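Purely as an illustration of the equivalence claimed above (hypothetical tensors, not the actual patch): an integral/boolean mask gives the same result as a float mask, while skipping a dtype round-trip that can be costly under XLA.

```python
import torch

logits = torch.randn(2, 5, 10)
pad_mask = torch.zeros(2, 5, dtype=torch.bool)
pad_mask[0, 3:] = True  # pretend the last positions of sample 0 are padding

# float-mask style: convert the mask to float and multiply
float_masked = logits * (1.0 - pad_mask.float()).unsqueeze(-1)

# int/bool-mask style: use the integral mask directly
bool_masked = logits.masked_fill(pad_mask.unsqueeze(-1), 0.0)

assert torch.allclose(float_masked, bool_masked)  # mathematically equivalent
```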
Summary: Pull Request resolved: facebookresearch#899 Differential Revision: D16448602 Pulled By: myleott fbshipit-source-id: afd1a1b713274b6328150cd85d7f8a81833597aa
Summary: I sadly discovered that my checkpoint directory wasn't globally readable after 8 hours of training. Adding this check at the beginning of the training loop to keep that from happening again! Reviewed By: myleott Differential Revision: D16455394 fbshipit-source-id: 35959aa058150b2afb63710c468d01ebc8a12b0c
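A minimal sketch of such a fail-fast check (hypothetical helper; the actual implementation may differ): write a probe file into the save directory before training starts, so a permissions problem surfaces immediately rather than eight hours in.

```python
import os


def verify_checkpoint_directory(save_dir):
    """Raise early if the checkpoint directory cannot be written to."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    probe = os.path.join(save_dir, 'dummy')
    try:
        with open(probe, 'w'):
            pass
    except OSError:
        print('| Unable to access checkpoint save directory: {}'.format(save_dir))
        raise
    else:
        os.remove(probe)
```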
Summary: Pull Request resolved: fairinternal/fairseq-py#770 Differential Revision: D16491911 Pulled By: myleott fbshipit-source-id: 8dd2b76f8fa24183640ae9d1129ea47ded77d43d
facebookresearch#769) Summary: Input feeding generally refers to a slightly different concept Pull Request resolved: fairinternal/fairseq-py#769 Differential Revision: D16491898 Pulled By: myleott fbshipit-source-id: 68573584e820f11f199db4e7e37e9ee7a69a3287
Summary: Pull Request resolved: fairinternal/fairseq-py#778 Differential Revision: D16525447 Pulled By: myleott fbshipit-source-id: e721e3a10e243a2408a04f89f06b5adbbe2fdff2
Summary: Pull Request resolved: facebookresearch#909 Differential Revision: D16532919 Pulled By: myleott fbshipit-source-id: 16ce884cf3d84579026e4406a75ba3c01a128dbd
Summary: Pull Request resolved: facebookresearch#910 Differential Revision: D16536532 Pulled By: myleott fbshipit-source-id: 56bb5570e70b5670ad87c64d9dd20c64c1fa9f5c
Summary: Pull Request resolved: facebookresearch#913 Differential Revision: D16536562 Pulled By: myleott fbshipit-source-id: ce28642da6868ec884e3e416388a652977a062df
Summary: Pull Request resolved: facebookresearch#911 Differential Revision: D16536559 Pulled By: myleott fbshipit-source-id: 7fe495054ce5b7658b1d3a43eca38c5858360236
Squashed commit message:
- optimizer fix
- progress bar comment out temporarily
- some changes to train_tpu
- int mask instead of float
- pfpfpfpf
- fix printing device index per loop
- bkpt to investigate resize_ call
- attempting to init buffer size to 2*dim
- bkpt
- better print
- do not drop records when computing loss
- Changes that reduce graph compiles:
  * Loss function replaced with an equivalent logic that doesn't resize tensors.
  * cli args changed to guarantee consistency
  * collate_tokens function in fairseq/data/data_utils.py overwritten to guarantee consistency
- undoing some changes made while debugging
- progress_bar implements len
- some irrelevant changes to train_tpu.py
- new xla changes
- bug fix in enable_torch_version
- removing the last batch that is of different size from the iterator
- delete optimizer step in fairseq's trainer. Added `self.xla` flag that controls if Trainer includes optimizer step + tried to include more explanation of why the optimizer step is skipped this time
- deleted obsolete file
- add norm clipping count back in (#4)
- remove grad norm clip count (#5)
- Change masked_fill_ input in loss in order to accommodate necessary pytorch changes (#6)
- Adding tpu capabilities to train.py (facebookresearch#8)
  * Adding tpu capabilities to train.py
  * flush when printing for better user experience
  * separated cli_main into parse_args, maingpu and maintpu; deleted unused line in datautils.py
- Enumerate the loader in training and validation (facebookresearch#9)
  * Adding tpu capabilities to train.py
  * flush when printing for better user experience
  * separated cli_main into parse_args, maingpu and maintpu; deleted unused line in datautils.py
  * Enumerate the loader
  * enumerate the loader
- Add option to assert on training and/or validation loss (facebookresearch#10)
  * Add option to assert on training and/or validation loss
  * applied suggestion
- None loss should be filled to inf (facebookresearch#11)
- Enabling multiprocessing for fairseq training. (facebookresearch#12)
  * initial commit for multiprocess api
  * indentation fixes and import fix
  * no need to softlink, fix save/load
  * Remove the hacks to only save from master ordinal as xm.save takes care of that
  * fix indentation; 3 -> 4 spaces
  * Moved xu.eprints after spawn and dropping last batches
- better trainers->trainer (facebookresearch#13)
- fix bug in assert_on_losses
- Replace usage of unsqueeze with transpose + broadcasting (facebookresearch#15)
- remove attn mask + loss rewrite + save per host + format
- suppress loss report
- allow usage of batch_by_size in translation. attn_weights masked fill in place
- Clean up the log output, suppressing a bit
- Revert multihead attn's in_proj code changes: the non-rebased tpu branch is about 10% faster on TPUs compared to the rebased branch. The regression is inside multihead attn's in_proj mechanism. Reverting the relevant changes to preserve performance.
- Pass correct args to the new get_valid_stats function
- Send meters to device in order not to fail training when resuming from chkpt
… taylanbil-tpu-rebase-master
For completeness, I'm doing an e2e run atm. Will check back on Monday.
One flyby comment before review 😉
7a23b93 is the squashed version of all my commits over the last months :'( I didn't really intend to do any changes other than resolving conflicts, but it all got messed up in the end, sorry.
It is pretty hard to review things like this.
fairseq/checkpoint_utils.py (Outdated)

```python
if len(checkpoints) > 0:
    trainer.save_checkpoint(checkpoints[0], extra_state)
    for cp in checkpoints[1:]:
        try:
            from fairseq.fb_pathmgr import fb_pathmgr
            fb_pathmgr.copy(checkpoints[0], cp, True)
if getattr(args, 'use_gpu', True) or xm.is_master_ordinal():
```
Maybe move this out of the `try` into a local and reuse.
will do
Are these your changes? Did they come with a rebase? Hard to tell 😉
You may want to squash these commits and fix the comments (not just list the full original comments).
Apologies, I honestly don't know what the proper GitHub etiquette is for PRs like this. This is basically months of work upstream, merged with months of work in our fork.
I don't really have any new functionalities, changes, etc. in this PR. Everything I have here is already-reviewed stuff, plus merge conflict resolutions. Those resolutions are really the only interesting bits of the PR, but there's no way to isolate them AFAIK.
I don't think I can squash the upstream commits, can I? I can squash the last 6 commits, but I tried that and failed miserably :'( (#18)
I am not the PITA guy asking to split PRs for no reason.
But in this case, the vast majority of this did not need a review, as it was stuff coming off a rebase.
I think you created this PR from the wrong tip of the repo, which ended up including all the commit IDs from the rebase.
In general, about the code: we need to keep in mind that this is the code folks will look at to understand what we are doing to optimize PyTorch code for TPU, so clean code and comments would really help.
You can squash your commits:
`git rebase -i HEAD~8`
Read the documentation about it.
Then force-push.
fairseq/data/data_utils.py (Outdated)

```python
batches = [[] for _ in input_shapes]
for idx in indices:
    sample_len = num_tokens_fn(idx)
    for j, (batch_size, padlen) in enumerate(input_shapes):
```
Is this assuming that the `input_shapes` list will be sorted by shortest to longest `padlen`?
yes
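For readers following along, a hypothetical sketch of the batching scheme this diff belongs to, under the assumption confirmed above that `input_shapes` is a list of `(batch_size, padlen)` pairs sorted from shortest to longest `padlen`; each sample goes into the smallest shape that fits it, so every emitted batch has one of a fixed set of shapes (which keeps XLA from recompiling on new input shapes):

```python
def batch_by_input_shapes(indices, num_tokens_fn, input_shapes):
    """Yield batches of indices whose (batch_size, padlen) match one of input_shapes.

    input_shapes must be sorted by increasing padlen.
    """
    batches = [[] for _ in input_shapes]
    for idx in indices:
        sample_len = num_tokens_fn(idx)
        for j, (batch_size, padlen) in enumerate(input_shapes):
            if sample_len > padlen:
                continue  # too long for this shape; try the next (longer) one
            batches[j].append(idx)
            if len(batches[j]) == batch_size:
                yield batches[j]
                batches[j] = []
            break
```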
```diff
@@ -178,6 +179,7 @@ def __init__(
         if self.align_dataset is not None:
             assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
         self.append_bos = append_bos
+        self.input_shapes = input_shapes
```
Optional: maybe we could add a docstring for this guy and clarify how it should be sorted?
see https://github.com/pytorch-tpu/fairseq/blob/tpu/train.py#L291-L298 — we error out while parsing the input args if the shapes passed in don't satisfy the assumption, and describe the requirements there.
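Something along these lines, presumably (hypothetical sketch; the actual check lives at the link above):

```python
def validate_input_shapes(input_shapes):
    # input_shapes: list of (batch_size, padlen) pairs supplied on the command line
    padlens = [padlen for _, padlen in input_shapes]
    if padlens != sorted(padlens):
        raise ValueError(
            'input_shapes must be sorted by increasing padlen, got {}'.format(padlens))
```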
```diff
@@ -146,23 +155,19 @@ def forward(
         saved_state = None

         if self.self_attention:
             q = self.q_proj(query)
```
For my understanding, was this for performance improvements? If so did it help?
yes, this was causing a 10% regression.
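For readers without the diff context, a rough, hypothetical illustration of the two ways of computing q/k/v that this thread is about: three separate projection modules versus one fused projection that is split afterwards. Which variant is faster can differ between GPU and XLA/TPU backends; per the commit notes, the variant kept in this branch avoided the ~10% TPU regression.

```python
import torch
import torch.nn.functional as F

embed_dim = 512
query = torch.randn(10, 2, embed_dim)  # (time, batch, embed_dim), self-attention case

# Variant A: separate q/k/v projection modules (as in the q_proj(...) line above)
q_proj = torch.nn.Linear(embed_dim, embed_dim)
k_proj = torch.nn.Linear(embed_dim, embed_dim)
v_proj = torch.nn.Linear(embed_dim, embed_dim)
q, k, v = q_proj(query), k_proj(query), v_proj(query)

# Variant B: one fused in_proj weight, applied in a single matmul and then chunked
in_proj_weight = torch.nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
in_proj_bias = torch.nn.Parameter(torch.zeros(3 * embed_dim))
torch.nn.init.xavier_uniform_(in_proj_weight)
q2, k2, v2 = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
```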
```diff
 self.meters['clip'].update(
-    1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
+    0.
+    # 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
```
We don't need this comment, right? And all the other comments in the PR?
yeah, this line causes a _local_scalar_dense, so I replaced it with 0. The meter is essentially invalid. I'll cancel this meter if xla.
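For context, a small standalone illustration of the issue being described (the tensors here are ordinary CPU tensors; imagine `grad_norm` as a lazy XLA device tensor):

```python
import torch

grad_norm = torch.tensor(5.0)  # stand-in for the lazy XLA tensor in the trainer
clip_norm = 0.25

# Evaluating this condition needs bool(grad_norm > clip_norm); on XLA that
# materializes the lazy tensor via _local_scalar_dense, i.e. a device-to-host
# copy (and a graph execution) on every training step.
clip = 1. if grad_norm > clip_norm and clip_norm > 0 else 0.

# Workaround used here: record a constant instead, keeping the step fully
# asynchronous at the cost of a meaningless 'clip' meter (to be revisited).
clip = 0.
```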
actually, may I handle this in a separate pr? I have one in the works that will change the logic around meters a bit.
train.py (Outdated)

```python
from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
```
Can we remove unused imports like `data_parallel`?
yep, will do, thanks.
Synced offline with @dlibenzi, I am now adding comments in the code in the places I edited vs upstream/master.
Here are all my changes. I'm in the process of adding comments for them.
Approving, but as I mentioned in another comment, the code here will be used as a reference by folks wanting to embark on similar models.
So I wish we could get this a bit more documented.
…d the multihead attention switch case
…d the multihead attention switch case
Re-do of #18
taylanbil#1 shows equivalence between the two branches `tpu-rebase-master` and `taylanbil-tpu-rebase-master`. For some reason, this PR can be merged and the other can't. I cannot git, sorry. This rebase will enable working on RoBERTa.
Transformer on WMT18 performance is about the same. Convergence is verified via the BLEU score reaching 25.83 in epoch 5 (will not merge before verifying e2e convergence) and validation loss going down similarly to the current tpu branch.
Getting this out there for review as it's a big PR. My commits are the last two, and the second-to-last one is a giant commit: many commits squashed into one.