How to set a proper batchsize when using CachedMultipleNegativesRankingLoss? #3134

awmoe opened this issue Dec 13, 2024 · 3 comments

awmoe commented Dec 13, 2024

When using MultipleNegativesRankingLoss, I tried different batch size (per_device_train_batch_size) settings and found that 512 was the maximum: with a batch size greater than 512, I ran into GPU out-of-memory (OOM) errors.

As stated in the documentation of CachedMultipleNegativesRankingLoss:

> GradCache is a smart way to solve this problem. It achieves the goal by dividing the computation into two stages of embedding and loss calculation, which both can be scaled by mini-batches. As a result, memory of constant size (e.g. that works with batch size = 32) can now process much larger batches (e.g. 65536).

So I tried CachedMultipleNegativesRankingLoss, and its mini_batch_size can go as high as 2048; a mini_batch_size greater than 2048 causes a GPU memory OOM.

Nevertheless, with mini_batch_size set to 2048, I can still increase the global batch size (per_device_train_batch_size). Generally speaking, a larger batch size achieves better performance in contrastive learning settings. So I tried different values of per_device_train_batch_size and found it can be as large as 1048576 without causing a GPU memory OOM (though the GPU utilization is 100%). So I am wondering how to set a proper per_device_train_batch_size: can it be infinitely big?

tomaarsen (Collaborator) commented:

Hello!

I should indeed clarify the documentation more, but the gist is that with CachedMultipleNegativesRankingLoss you can increase the per_device_train_batch_size to large numbers while keeping the mini_batch_size smaller. You should choose these like this (a small configuration sketch follows the list):

  1. Set mini_batch_size in CMNRL to the highest value your system can handle, for example 512 in your case (or perhaps 2048?). This is the number of samples processed at the same time: with CMNRL, this is the only parameter that increases or decreases your memory usage. Beyond that, it does not affect the training performance. Using mini_batch_size=1, 32, 512, or 2048 doesn't change the training results, just the memory usage and training speed.
  2. Set per_device_train_batch_size to any number. It can't be infinitely big, but you should be able to set it as large as you want. You've already noticed that you can set this very high, e.g. to 1048576. Higher tends to be better for contrastive learning, but I'm sure there are limits.
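
In code, the gist is roughly the sketch below; the model name and the exact numbers are placeholders, not a recommendation:

```python
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model

# 1. mini_batch_size only affects memory usage and training speed, not the results
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=2048)

# 2. per_device_train_batch_size is the "real" batch size that determines how
#    many in-batch negatives each sample is contrasted against
args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=65536,
)
```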

So you're very much on the right track already! One thing to consider is the relation between batch size (per_device_train_batch_size) and learning rate. Generally, if you have a higher batch size, you have to set a higher learning rate. Otherwise, each individual sample will be "learned from" a bit less.

An old trick/guideline is that if you know a good learning rate at a "normal" batch size of e.g. 64, and you want to increase to a (much) larger batch size, then you should increase your learning rate by a factor of sqrt(new_batch / old_batch). I normally start trying some values around there: the learning rate is the most important hyperparameter alongside the batch size.

So, if you would normally use a 2e-5 learning rate with a batch size of 32, and you now want to use a batch size (per_device_train_batch_size) of 65536, then I would try setting the learning rate to:

new_lr = old_lr * sqrt(new_batch / old_batch)
new_lr = 2e-5 * sqrt(65536 / 32)
new_lr = 2e-5 * sqrt(2048)
new_lr ≈ 2e-5 * 45.25
new_lr ≈ 0.9e-3

I usually start there, and then double or halve the value until I get the strongest results.
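
Concretely, that back-of-the-envelope calculation plugged into the training arguments looks roughly like this (the numbers are placeholders and the result is only a starting point):

```python
import math

from sentence_transformers import SentenceTransformerTrainingArguments

base_lr, base_batch = 2e-5, 32   # a known-good pair at a "normal" batch size
target_batch = 65536             # the new per_device_train_batch_size

# scale the learning rate by sqrt(new_batch / old_batch)
new_lr = base_lr * math.sqrt(target_batch / base_batch)
print(f"{new_lr:.2e}")           # 9.05e-04, i.e. roughly 0.9e-3

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=target_batch,
    learning_rate=new_lr,        # starting point; double or halve from here
)
```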

  • Tom Aarsen

awmoe commented Dec 18, 2024

Thank you @tomaarsen very much for your detailed explanation!

With CachedMultipleNegativesRankingLoss, I increased the batch size (per_device_train_batch_size) from 512 to 8192 but forgot to change the learning rate, and the evaluation results showed that the performance with 8192 was worse than with 512. Your explanation reminded me that I need to tune my learning rate.

Just one more thing: when I used CachedMultipleNegativesRankingLoss and increased the batch size (per_device_train_batch_size) from 512 to 8192, I found my GPU utilization fluctuating between low (25%) and high (100%). When the batch size was further increased to 65536, the GPU utilization fluctuated between 0% and 100%. Did the dataset loading become the bottleneck? How can I optimize the training?

tomaarsen (Collaborator) commented:

> Just one more thing: when I used CachedMultipleNegativesRankingLoss and increased the batch size (per_device_train_batch_size) from 512 to 8192, I found my GPU utilization fluctuating between low (25%) and high (100%). When the batch size was further increased to 65536, the GPU utilization fluctuated between 0% and 100%. Did the dataset loading become the bottleneck? How can I optimize the training?

Hmm, it's possible that the batch sampler, tokenizer, or dataloader indexing does indeed become a bottleneck. I've not done too much testing here, but the TrainingArguments instance has some dataloader parameters. Do note that our tokenize function is inherently linked to the model itself, so e.g. dataloader_num_workers might not work out of the box, because duplicating the data collator with the tokenize function (and thus the full model) isn't feasible. There are possible fixes, e.g. overriding the tokenize function in the collator with one that isn't tied to the model (but does the same), but it gets a bit hacky at some point.
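
If you do want to experiment, the dataloader knobs live on the TrainingArguments (inherited from transformers); roughly something like the sketch below, keeping the caveat about dataloader_num_workers in mind:

```python
from sentence_transformers import SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=65536,
    # Dataloader knobs inherited from transformers.TrainingArguments.
    # dataloader_num_workers > 0 may not work out of the box here because the
    # collator's tokenize function is tied to the model (see above).
    dataloader_num_workers=2,
    dataloader_prefetch_factor=2,
    dataloader_pin_memory=True,
)
```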

  • Tom Aarsen
