-
Notifications
You must be signed in to change notification settings - Fork 636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[T104292598] Refactor the "LRA" training code -> Pytorch Lightning #343
Conversation
* minor cleanup; updated changelog * fixed mypy error * added checking for blocksparse availability Co-authored-by: Chris Yuan <christopheryuan@learnfair1490.h2.fair> Co-authored-by: Chris Yuan <christopheryuan@devfair0278.h2.fair>
@dianaml0 @blefaudeux Hope it's okay to tag you both as reviewers based on the original task (T104292598). |
sounds great, thank you for working on this ! |
|
||
return model | ||
|
||
|
||
def build_training_setup( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh wow, this is some sizeable cleanup.. thank you !
xformers/benchmarks/LRA/run_tasks.py
Outdated
|
||
# Training epochs | ||
if accumu_steps > 1: | ||
config_training["num_train_steps"] *= accumu_steps |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that this is still required with lightning, it handles grad accumulation out of the box, right ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought so too but based off this warning, it sounds like gradient accumulation coupled with DDP behaves differently. I'll look into it in more detail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is fine, it just means that prior to the optimizer step the gradients will not be in sync across the fleet, while it's the case without the accumulation. It's a useful warning if you were to peek into the gradients on a per-rank basis, and decide on something from that, in that case triggering the grad acc could mess with your logic. We're not doing that here, simply training over all the ranks, so the default lightning behaviour should be fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah makes sense—thanks! Will get rid of this block then.
Looks good to me @lisjin, thank you for all this work ! Quick follow up, sorry for the delay
|
Codecov Report
@@ Coverage Diff @@
## main #343 +/- ##
==========================================
- Coverage 93.91% 93.89% -0.03%
==========================================
Files 70 70
Lines 3960 3962 +2
==========================================
+ Hits 3719 3720 +1
- Misses 241 242 +1
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
@blefaudeux I ran experiments for
|
@@ -277,7 +277,9 @@ def forward( | |||
|
|||
# Apply the optional input masking | |||
if encoder_input_mask is not None: | |||
x += encoder_input_mask.unsqueeze(0).unsqueeze(-1) | |||
if x.dim() - encoder_input_mask.dim() > 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to add this check to avoid a tensor shape mismatch error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, well done, thanks for fixing that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks a million @lisjin ! I'll defer to Diana and Francisco for another validation/landing, but I think that it's a lot cleaner indeed.. thanks for the validation runs also, great PR which was not trivial !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for your contribution! Great to have this improvement and that the LRA results have been validated!
def __init__(self, config, model_name): | ||
super().__init__() | ||
|
||
config_model = config["model"] | ||
self.config_training = config["training"] | ||
|
||
self.enable_amp = config["training"]["mixed_precision"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like this is no longer being used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A little buried, but it's being used in configure_optimizers
.
…atch-1 Fix duplicate calculations on baseline for mem efficient transformers
What does this PR do?
Refactor LRA run_tasks.py so that it uses Pytorch Lightning as a trainer.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.