Add Flax BART pretraining script #18297
…a21/transformers into add-flax-pretraining-bart
This looks very nice to me! Just a small question, why is the file called […]
@patil-suraj could you take a look here as well?
This looks good to me! Maybe @sanchit-gandhi can also take a quick look here
@patrickvonplaten It stands for […]
LGTM! Thank you @duongna21 for adding this example. Really interesting to see the training logs too!
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
Nice! This is a clean way of normalising the `eval_metrics`!
Technically speaking, in the `train_step`, the `pmap` won't compute a 'true' mean over devices. Here, what you're doing is computing a normalised loss on each device, and then averaging these losses over devices. This isn't strictly equal to summing the losses over all devices, and then dividing by the number of samples.
Let $K$ be the number of devices. Denote the loss on the $i$-th device by $L_i$ (`loss.sum()`) and the number of samples by $N_i$ (`label_mask.sum()`). In the `loss_fn`, we compute the normalised loss on each device (`loss.sum() / label_mask.sum()`):

$$\bar{L}_i = \frac{L_i}{N_i}$$

and then average over devices with the `pmap`:

$$\mathcal{L}_{\text{pmap}} = \frac{1}{K} \sum_{i=1}^{K} \frac{L_i}{N_i}$$

Whereas, for a 'true' loss, we should first add up all the losses over devices:

$$L = \sum_{i=1}^{K} L_i$$

and then divide by the total number of labels:

$$\mathcal{L}_{\text{true}} = \frac{L}{N}$$

where

$$N = \sum_{i=1}^{K} N_i$$

If we compare the two and ignore the constant factor $1/K$ from the `pmap` average:

$$\sum_{i=1}^{K} \frac{L_i}{N_i} \neq \frac{\sum_{i=1}^{K} L_i}{\sum_{i=1}^{K} N_i}$$

we see that the losses are in fact different. The first expression is what you get if you average the losses on each device, then average these terms over devices with a `pmap`. The second expression is a 'true' loss, what you get by summing the losses on each device, summing these losses over devices, and then dividing by the total number of terms in your batch (= sum of the `label_mask` per device, summing these terms over devices).
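To make the difference concrete, here is a small numerical sketch in plain Python. The per-device loss sums and label counts are made-up illustrative numbers, and the two helper functions are hypothetical names, not code from the script under review.

```python
# Hypothetical per-device values (illustrative numbers only).
device_loss_sums = [6.0, 2.0]   # loss.sum() on each device
device_label_counts = [3, 4]    # label_mask.sum() on each device


def mean_of_device_means(loss_sums, label_counts):
    """Normalise on each device, then average over devices (pmean-style)."""
    per_device = [l / n for l, n in zip(loss_sums, label_counts)]
    return sum(per_device) / len(per_device)


def true_mean(loss_sums, label_counts):
    """Sum losses and label counts over devices, then divide once."""
    return sum(loss_sums) / sum(label_counts)


pmean_style = mean_of_device_means(device_loss_sums, device_label_counts)
true_style = true_mean(device_loss_sums, device_label_counts)
print(f"mean of per-device means: {pmean_style:.4f}")
print(f"true mean over all labels: {true_style:.4f}")
```

With unequal label counts the two values differ (1.25 versus 8/7 here). If every device holds the same number of labels, the two functions agree, which is why the discrepancy only shows up when the label mask varies between devices.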
You're right, we should not simply use `pmean` in case of masked labels. I actually borrowed this implementation from `run_mlm_flax.py`.
Indeed, we should update the other Flax training examples to reflect this!
@sanchit-gandhi Sure, this trick should be applied to tasks that do not have the same […]: `run_mlm_flax`, `run_bart_dlm_flax`, `run_summarization_flax` and `run_image_captioning_flax`.
However, the changes needed in `run_summarization_flax` and `run_image_captioning_flax` are nontrivial because of the existence of the `compute_loss` function (which should return only one scalar, the `loss.sum()`). Also, I think we should apply this trick to `train_loss` to make it comparable to `eval_loss`.
I think it's better to do these in another PR! (coming soon)
Great point! When the batch size is fixed (such as in `run_summarization_flax` and `run_image_captioning_flax`) there is no padding, so we are not obliged to normalise the loss by the number of tokens (as this stays fixed for every batch).
We can either use the trick you used in `eval_loss`, or compute a `psum` over devices for the losses and number of tokens and then normalise accordingly: https://github.com/sanchit-gandhi/seq2seq-speech/blob/cfc6d73959486f5bd71c623ddd95843d62f5a614/run_flax_speech_recognition_seq2seq.py#L1252-L1255
Feel free to leave it as is for this PR - the numerical differences end up being small when averaged over many batches and large amounts of data. Would make for a good follow-up PR!
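A minimal sketch of the `psum` variant, for illustration only and not a copy of the linked implementation. It assumes a named axis `"batch"` and uses `jax.vmap(..., axis_name=...)` as a stand-in for `pmap`, so that the collectives run on a single device. The function names and input values are hypothetical.

```python
import jax
import jax.numpy as jnp


def psum_normalised_loss(per_token_loss, label_mask):
    # Per-"device" sums of the masked loss and of the label mask.
    loss_sum = (per_token_loss * label_mask).sum()
    num_labels = label_mask.sum()
    # Sum both quantities over the named axis, then divide once.
    total_loss = jax.lax.psum(loss_sum, axis_name="batch")
    total_labels = jax.lax.psum(num_labels, axis_name="batch")
    return total_loss / total_labels


def pmean_of_local_means(per_token_loss, label_mask):
    # Normalise locally, then average over the named axis (for comparison).
    local_mean = (per_token_loss * label_mask).sum() / label_mask.sum()
    return jax.lax.pmean(local_mean, axis_name="batch")


# Two simulated devices with three token positions each (made-up values).
per_token_loss = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
label_mask = jnp.array([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])

# vmap with axis_name replaces pmap here so the sketch needs only one device.
true_loss = jax.vmap(psum_normalised_loss, axis_name="batch")(per_token_loss, label_mask)
naive_loss = jax.vmap(pmean_of_local_means, axis_name="batch")(per_token_loss, label_mask)
print(true_loss, naive_loss)
```

In a real training step the same two `psum` calls would sit inside the function passed to `jax.pmap` with a matching `axis_name`.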
Thanks for the nice trick!
In fact, I have yet to grasp the idea that "batch size is fixed" = "num of tokens stays fixed for every batch" in `run_summarization_flax` and `run_image_captioning_flax`. Could you elaborate on this a little bit?
Sorry, ignore that remark! I checked and it applies equally well to both of those! I remember discussing this point with @patil-suraj offline a couple of months ago. IIRC, in the example given for `run_summarization_flax`, the targets are always truncated to `max_target_length`, and so effectively there is no padding. But in general, the `pmean` won't return a 'true' loss.
Note that we can actually bypass the constraint of the `compute_loss` function only returning one scalar (the total loss). If we pass the flag `has_aux=True` to `jax.value_and_grad` we can return a second, auxiliary output (https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html?highlight=value_and_grad). If we return two outputs, `loss` and `num_labels`, from our `compute_loss` function, we can take gradients wrt the `loss` and also return the `num_labels`:
def compute_loss(params):
    ...
    return loss, num_labels

# take grads w.r.t. the first output (loss), also return a second output (num_labels)
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
# with has_aux=True the result is nested as ((value, aux), grad)
(loss, num_labels), grad = grad_fn(state.params)
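For reference, a self-contained toy version of the same pattern. The linear model, the inputs and the extra `compute_loss` arguments are hypothetical stand-ins for the real model and batch. Only the `has_aux=True` behaviour of `jax.value_and_grad` comes from the comment above.

```python
import jax
import jax.numpy as jnp


def compute_loss(params, inputs, targets, label_mask):
    # Toy linear model with a masked squared error (stand-in for the real model).
    preds = inputs * params["w"]
    loss = (((preds - targets) ** 2) * label_mask).sum()
    num_labels = label_mask.sum()
    # The first output is differentiated; the second is passed through untouched.
    return loss, num_labels


params = {"w": jnp.array(2.0)}
inputs = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.0, 2.0, 3.0])
label_mask = jnp.array([1.0, 1.0, 0.0])  # last position is padding

grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
# The return value is nested: ((loss, aux), grads).
(loss, num_labels), grads = grad_fn(params, inputs, targets, label_mask)
print(loss, num_labels, grads["w"])
```

Because `num_labels` comes back alongside the summed loss, both can be reduced over devices before normalising, rather than normalising on each device first.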
Aha, I unfortunately forgot the `has_aux=True` flag. Thanks for the reminder!
Regarding the num of target tokens, what you said makes sense if the target length exceeds `max_target_length`, but what if it is shorter?
And do you think this PR is ready to be merged? @sanchit-gandhi @patrickvonplaten
Having provided the pointers related to the data collator, @patil-suraj might wish to take a look before merging!
Your reviews are extremely helpful @sanchit-gandhi! All of them have been resolved. Thank you!
@patil-suraj Could you have a look at this PR? I'd love to hear your feedback.
Hey @duongna21! Great job on this PR, and thank you for addressing the comments! For the time being, let's hold off on gradient accumulation. If you require it for your personal experiments, it can be achieved quite easily using the Optax wrapper `MultiSteps`. However, in my experience, this wrapper is pretty memory inefficient and does not yield particularly good performance; it applies a dummy update of zeros for […]
Otherwise, all the implementation TODOs are complete, the code review approved, and the training results on track, so happy to go ahead and merge!
@sanchit-gandhi Yeah, I agree that gradient accumulation shouldn't be added until there is a more elegant way to implement it, so feel free to merge this PR. Thank you for the helpful advice!
* add bart pretraining flax script
* fixup
* add bart pretraining flax script
* add BART to README
* add BART to README
* add BART to README
* add BART to README
* add BART to README
* add bos eos document
* Update README.md
* Update README.md
* Update examples/flax/language-modeling/run_bart_dlm_flax.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* final
* final
* final
* remove use_auth_token ing from_config
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@duongna21 I wonder why the […]
What does this PR do?
Fixes #6743 #18030 #4151 #5096. Adds Flax script for BART pretraining.
Inspired by @patil-suraj's suggestion, I modified @morganmcg1's `DataCollatorForDenoisingTasks` to create a BART denoising pretraining script in Flax.
Implementation details from the paper:
Training statistics when pre-training bart-base in Norwegian on a single TPUv3-8 pod:
Who can review?
cc potential reviewers: @patrickvonplaten, @patil-suraj, @sgugger, @LysandreJik