How to set a proper batch size when using CachedMultipleNegativesRankingLoss? #3134
Hello! I should indeed clarify the documentation more, but the gist is that with CachedMultipleNegativesRankingLoss, the mini_batch_size only affects memory usage and training speed, while per_device_train_batch_size is the "real" batch size: it determines how many in-batch negatives each sample is contrasted against, and therefore the training performance.

So you're very much on the right track already! One thing to consider is the relation between batch size (per_device_train_batch_size) and learning rate. An old trick/guideline is that if you know a good learning rate at a "normal" batch size of e.g. 64, and you want to increase to a (much) larger batch size, then you should scale up your learning rate as well; a common rule of thumb is to multiply it by the square root of the batch size increase. So, if you might use a 2e-5 learning rate with a batch size of 32, and you now want to use a batch size that is k times larger, then 2e-5 * sqrt(k) is a reasonable starting point.
I usually start there, and then double or halve the value until I get the strongest results.
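To make that rule of thumb concrete, here is a minimal sketch (the sqrt-scaling and the helper name are just an illustration of the guideline above, not an API of the library):

```python
import math

def scale_learning_rate(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    """Scale a known-good learning rate when growing the batch size,
    using the sqrt rule of thumb: new_lr = base_lr * sqrt(new / base)."""
    return base_lr * math.sqrt(new_batch_size / base_batch_size)

# e.g. 2e-5 tuned at batch size 32, scaled up for batch size 8192:
print(scale_learning_rate(2e-5, 32, 8192))  # 2e-5 * sqrt(256) = 3.2e-04
```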
Thank you @tomaarsen very much for your detailed explanation! With CachedMultipleNegativesRankingLoss, I increased the batch size (per_device_train_batch_size) from 512 to 8192 but forgot to change the learning rate, and the evaluation results showed that the performance with 8192 was worse than with 512. Your explanation reminded me that I need to tune my learning rate. Just one more thing: when I used CachedMultipleNegativesRankingLoss and increased the batch size (per_device_train_batch_size) from 512 to 8192, I found my GPU utilization fluctuating between low (25%) and high (100%). When the batch size (per_device_train_batch_size) is further increased to 65536, the GPU utilization fluctuates between 0% and 100%. Has the dataset (data loading) become the bottleneck? How can I optimize the training?
Hmm, it's possible that the batch sampler, tokenizer, or dataloader indexing becomes a bottleneck indeed. I've not done too much testing here, but the TrainingArguments instance has some dataloader-related options (e.g. dataloader_num_workers, dataloader_prefetch_factor, dataloader_pin_memory) that you can experiment with to keep the GPU fed.
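For reference, a sketch of where those dataloader-related arguments live (the values are illustrative starting points to tune, not recommendations):

```python
from sentence_transformers import SentenceTransformerTrainingArguments

# Dataloader knobs inherited from transformers.TrainingArguments that can help
# when the CPU-side pipeline (sampler/tokenizer/indexing) starves the GPU.
args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=8192,
    dataloader_num_workers=4,      # parallel CPU workers for data loading
    dataloader_prefetch_factor=2,  # batches prefetched per worker (requires num_workers > 0)
    dataloader_pin_memory=True,    # faster host-to-GPU transfers
)
```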
When using MultipleNegativesRankingLoss, I tried different batch size (per_device_train_batch_size) settings and found that 512 was the maximum: with a batch size greater than 512, the GPU ran out of memory (OOM).

As stated in the documentation of CachedMultipleNegativesRankingLoss, it trades some training speed for the ability to use much larger batch sizes at roughly constant memory, by splitting each batch into mini-batches (GradCache).

So I tried CachedMultipleNegativesRankingLoss, and its mini_batch_size can go as high as 2048; a mini_batch_size greater than 2048 causes a GPU OOM.

Nevertheless, with mini_batch_size set to 2048, I can still increase the global batch size (per_device_train_batch_size). Generally speaking, a larger batch size achieves better performance in contrastive learning settings. So I tried different batch sizes (per_device_train_batch_size) and found it can be as large as 1048576 without causing a GPU OOM (though the GPU utilization is 100%). So I am wondering how to set a proper batch size (per_device_train_batch_size): can it be arbitrarily large?
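For context, the setup looks roughly like this (a sketch assuming the standard SentenceTransformerTrainer pipeline; the model and dataset names are placeholders, not the ones actually used):

```python
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")  # placeholder data

# mini_batch_size only controls how the forward/backward passes are chunked,
# i.e. peak GPU memory and speed; per_device_train_batch_size is the "real"
# batch size that determines how many in-batch negatives each positive pair
# is contrasted against.
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=2048)

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=8192,  # the effective contrastive batch size
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
```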