Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Flax BART pretraining script #18297

Merged
merged 20 commits into from
Aug 1, 2022

Conversation

duongna21
Copy link
Contributor

@duongna21 duongna21 commented Jul 26, 2022

What does this PR do?

Fixes #6743 #18030 #4151 #5096. Adds Flax script for BART pretraining.

Inspired by @patil-suraj's suggestion, I modified the @morganmcg1's DataCollatorForDenoisingTasks to create a BART denoising pretraining script in Flax.

Implementation details from the paper:

  • Text infilling
  • Sentence permutation
  • Large training batch sizes (will add gradient accumulation for this and other Flax language modeling scripts in the next PR)

Training statistics when pre-train bart-base in Norwegian on a single TPUv3-8 pod:

Who can review?

cc potential reviewers: @patrickvonplaten, @patil-suraj, @sgugger, @LysandreJik

@duongna21 duongna21 changed the title add bart pretraining flax script Add BART pretraining Flax script Jul 26, 2022
@duongna21 duongna21 changed the title Add BART pretraining Flax script Add Flax BART pretraining script Jul 26, 2022
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jul 26, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Copy link
Contributor

This looks very nice to me! Just a small question, why is the file called run_bart_dlm_flax.py, i.e. why the dlm and not just lm ?

@patrickvonplaten
Copy link
Contributor

@patil-suraj could you take a look here as well ?

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me! Maybe @sanchit-gandhi can also take a quick look here

@duongna21
Copy link
Contributor Author

duongna21 commented Jul 27, 2022

This looks very nice to me! Just a small question, why is the file called run_bart_dlm_flax.py, i.e. why the dlm and not just lm ?

@patrickvonplaten It stands for denoising language modeling, which is consistent with mlm and clm. What do you think?

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you @duongna21 for adding this example. Really interesting to see the training logs too!

examples/flax/language-modeling/run_bart_dlm_flax.py Outdated Show resolved Hide resolved
examples/flax/language-modeling/run_bart_dlm_flax.py Outdated Show resolved Hide resolved
examples/flax/language-modeling/run_bart_dlm_flax.py Outdated Show resolved Hide resolved
examples/flax/language-modeling/run_bart_dlm_flax.py Outdated Show resolved Hide resolved
Comment on lines +899 to +900
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
Copy link
Contributor

@sanchit-gandhi sanchit-gandhi Jul 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is a clean way of normalising the eval_metrics!

Technically speaking, in the train_step, the pmap won't compute a 'true' mean over devices. Here, what you're doing is computing a normalised loss on each device, and then averaging these losses over devices. This isn't strictly equal to summing the losses over all devices, and then dividing by the number of samples.

Let $K$ denote the number of devices. Denote the loss on the $i$-th device as $L_i$ (loss.sum()) and the number of samples $N_i$ (label_mask.sum()). In the loss_fn, we compute the normalised loss on each device (loss.sum() / label_mask.sum()):

$$\bar{L}_i = \frac{L_i}{N_i}$$

and then average over devices with the pmap:

$$\mathcal{L} = \frac{1}{K} \sum_{i=1}^{K} \frac{L_i}{N_i}$$

Whereas, for a 'true' loss, we should first add up all the losses over devices:

$$L_{tot} = \sum_{i=1}^{K} L_i $$

and then divide by the total number of labels:

$$\mathcal{L}' = \frac{L_{tot}}{N} = \frac{1}{N}\sum_{i=1}^{K} L_i $$

where $N$ is the total number of labels:

$$ N = \sum_{i=1}^{K} N_i $$

If we compare the two and ignore the constant $K$ in the pmap average:

$$\mathcal{L} = \sum_{i=1}^{K} \frac{L_i}{N_i}$$

$$ \mathcal{L}' = \frac{1}{N}\sum_{i=1}^{K} L_i $$

we see that the losses are in-fact different. The first expression is what you get if you average the losses on each device, then average these terms over devices with a pmap. The second expression is a 'true' loss, what you get by summing the losses on each device, summing these losses over devices, and then dividing by the total number of terms in your batch (= sum of the label_mask per device, summing these terms over devices).

Copy link
Contributor Author

@duongna21 duongna21 Jul 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, we should not simply use pmean in case of masked labels. I actually borrowed this implementation from run_mlm_flax.py.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, we should update the other Flax training examples to reflect this!

Copy link
Contributor Author

@duongna21 duongna21 Jul 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sanchit-gandhi Sure, this trick should be applied to tasks that do not have the same $N_i$ for every batch, including run_mlm_flax, run_bart_dlm_flax, run_summarization_flax and run_image_captioning_flax.
However, the changes should be made in run_summarization_flax and run_image_captioning_flax are nontrivial because of the existance of compute_loss function (which should return only 1 scalar - the loss.sum()). Also, I think we should apply this trick to train_loss to make it comparable to eval_loss.
I think it's better to do these in another PR! (coming soon)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point! When the batch size is fixed (such as in run_summarization_flax and run_image_captioning_flax) there is no padding, so we are not obliged to normalise the loss by the number of tokens (as this stays fixed for every batch).

We can either use the trick you used in eval_loss, or compute a psum over devices for the losses and number of tokens and then normalise accordingly: https://github.com/sanchit-gandhi/seq2seq-speech/blob/cfc6d73959486f5bd71c623ddd95843d62f5a614/run_flax_speech_recognition_seq2seq.py#L1252-L1255

Feel free to leave it as is for this PR - the numerical differences end up being small when averaged over many batches and large amounts of data. Would make for a good follow-up PR!

Copy link
Contributor Author

@duongna21 duongna21 Jul 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the nice trick!

In fact, I have yet to grasp the idea of batch size is fixed = num of tokens stays fixed for every batch in run_summarization_flax and run_image_captioning_flax. Could you elaborate on this a little bit?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, ignore that remark! I checked and it applies equally well to both of those! I remember discussing this point with @patil-suraj offline a couple of months ago. IIRC, in the example given for run_summarization_flax, the targets are always truncated to max_target_length, and so effectively there is no padding. But in general, the pmean won't return a 'true' loss.

Note that we can actually bypass the constraint of the compute_loss function only returning one scalar (the total loss). If we pass the flag has_aux=True to jax.value_and_grad we can return a second, auxiliary output (https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html?highlight=value_and_grad). If we return two outputs, loss and num_labels from our compute_loss function, we can take gradients wrt the loss and also return the num_labels:

def compute_loss(params):
    ...
    return loss, num_labels

# take grads wrt to the first output (loss), also return a second output (num_labels)
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
loss, num_labels, grad = grad_fn(state.params)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, I unfortunately forgot the has_aux=True flag. Thanks for reminding!

Regarding the num of target tokens, what you said makes sense if the target length exceeds max_target_length, but what if it is shorter?

Copy link
Contributor Author

@duongna21 duongna21 Jul 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And do you think this PR is ready to be merged? @sanchit-gandhi @patrickvonplaten

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi Jul 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having provided the pointers related to the data collator, @patil-suraj might wish to take a look before merging!

@duongna21
Copy link
Contributor Author

duongna21 commented Jul 28, 2022

Your reviews are extremely helpful @sanchit-gandhi! All of them have been resolved. Thank you!

@duongna21
Copy link
Contributor Author

@patil-suraj Could you have a look at this PR? I'd love to hear your feedback.

@sanchit-gandhi
Copy link
Contributor

sanchit-gandhi commented Aug 1, 2022

Hey @duongna21! Great job on this PR, and thank you for addressing the comments!

For the time being, let's hold off on gradient accumulation. If you require it for your personal experiments, it can be achieved quite easily using the Optax wrapper MultiSteps. However, in my experience, this wrapper is pretty memory inefficient and does not yield particularly good performance; it applies a dummy update of zeros for $K-1$ train steps (redundant), and then the accumulated gradient update step on the $K$-th train step. Instead, writing a custom gradient accumulation training loop is more efficient (c.f. seq2seq-speech), but this involves quite a lot of additional code and is significantly more involved, so I'm not particularly in favour of using it for these streamlined examples scripts!

Otherwise, all the implementation TODOs are complete, the code review approved, and the training results on track, so happy to go ahead and merge!

@duongna21
Copy link
Contributor Author

duongna21 commented Aug 1, 2022

@sanchit-gandhi Yeah, I agree that gradient accumulation shouldn't be added until there is a more elegant way to implement it, so feel free to merge this PR. Thank you for the helpful advice!

@sgugger sgugger merged commit 3909d7f into huggingface:main Aug 1, 2022
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
* add bart pretraining flax script

* fixup

* add bart pretraining flax script

* add BART to README

* add BART to README

* add BART to README

* add BART to README

* add BART to README

* add bos eos document

* Update README.md

* Update README.md

* Update examples/flax/language-modeling/run_bart_dlm_flax.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* final

* final

* final

* remove use_auth_token ing from_config

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@StevenTang1998
Copy link
Contributor

@duongna21 I wonder why the permute_sentences only contains the pad token rather the full stop token?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BART for Pre-Training
6 participants