Mixed Precision Training #2149
Thanks! For now, this is pretty low on our priority list, but contributions would be welcome here.
Would love to see fp16 support in AllenNLP!
I played around with apex/amp and found that integrating it into AllenNLP is straightforward, with a measurable performance boost on a large CRF tagger model (RTX GPU). After cloning and installing apex, you only need to modify allennlp/training/trainer.py (or create your own trainer). Near the top, after the imports, import amp and initialize it:

```python
from apex import amp
amp_handle = amp.init()
```

Replace the optimizer assignment with:

```python
self.optimizer = amp_handle.wrap_optimizer(optimizer)
```

Finally, replace the `loss.backward()` call with:

```python
with self.optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
```

I tested this with a larger version of the crf_tagger.json tutorial config on CoNLL 2003 using … During training, I see log messages like … As @matthew-z mentioned, some AllenNLP functions are not stable, and I haven't investigated these. I'm not sure what additional work is necessary to submit a pull request. Thoughts?
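Conceptually, the `scale_loss` context manager performs dynamic loss scaling: the loss is multiplied by a large factor before `backward()` so that small fp16 gradients do not underflow to zero, and when the scaled gradients overflow to inf/NaN the optimizer step is skipped and the factor is reduced. The framework-free sketch below illustrates that logic on plain Python floats. The class and function names, the initial scale of 2**16, and the growth interval of 2000 are illustrative assumptions, not apex's actual implementation.

```python
import math


class DynamicLossScaler:
    """Toy dynamic loss scaler (hypothetical names; not apex's API)."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive steps without overflow

    def scale_loss(self, loss: float) -> float:
        # Scale the loss up so that small fp16 gradients do not underflow.
        return loss * self.scale

    def unscale(self, scaled_grads):
        """Divide the scale back out; also report whether every gradient is finite."""
        ok = all(math.isfinite(g) for g in scaled_grads)
        return [g / self.scale for g in scaled_grads], ok

    def update(self, ok: bool) -> None:
        if not ok:
            # Overflow: the caller skips this step, and the scale is halved.
            self.scale /= 2.0
            self._clean_steps = 0
            return
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # After a long run of clean steps, try a larger scale again.
            self.scale *= 2.0
            self._clean_steps = 0


def sgd_step(params, scaled_grads, scaler, lr=0.1):
    """One plain SGD update, skipped entirely when the gradients overflowed."""
    grads, ok = scaler.unscale(scaled_grads)
    if ok:
        params = [p - lr * g for p, g in zip(params, grads)]
    scaler.update(ok)
    return params
```

In the trainer code above, amp takes care of these steps itself; the sketch only makes the skip-and-rescale behaviour explicit.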
Awesome, thanks for trying this out! @joelgrus is currently thinking about how to make the … For functions that aren't stable, this is because of hard-coded values that were chosen with fp32 in mind, right? Like …
I'm still not 100% sure what I'm going to do about the trainer. In the meantime, if you wanted to do a PR, I would probably do the following (once you've done the correctness checking that mattg pointed out):

(1) give the current trainer a private method …
(2) replace the call to …
(3) create a subclass …

Like I said, I'm not quite sure yet how I'm going to clean up the Trainer code, but I'm pretty sure I could work with this.
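A minimal sketch of the refactoring outlined above, assuming the private method in (1) is a hook that wraps the backward pass. `Trainer` here is a stripped-down stand-in rather than AllenNLP's real class, and `_backward`, `train_step`, and `MixedPrecisionTrainer` are hypothetical names. The subclass accepts any callable shaped like `amp_handle.scale_loss(loss, optimizer)`, so the structure can be exercised without apex installed.

```python
from contextlib import AbstractContextManager
from typing import Any, Callable


class Trainer:
    """Stripped-down stand-in for a trainer (hypothetical; not AllenNLP's class)."""

    def __init__(self, optimizer: Any) -> None:
        self.optimizer = optimizer

    def _backward(self, loss: Any) -> None:
        # (1) Private hook: the default is an ordinary fp32 backward pass.
        loss.backward()

    def train_step(self, loss: Any) -> None:
        self.optimizer.zero_grad()
        # (2) Assumed: the direct loss.backward() call now goes through the hook.
        self._backward(loss)
        self.optimizer.step()


class MixedPrecisionTrainer(Trainer):
    """(3) Subclass that overrides only the backward hook."""

    def __init__(self, optimizer: Any,
                 scale_loss: Callable[[Any, Any], AbstractContextManager]) -> None:
        super().__init__(optimizer)
        # For example, amp_handle.scale_loss from apex's amp.init() API.
        self._scale_loss = scale_loss

    def _backward(self, loss: Any) -> None:
        with self._scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
```

The benefit of this shape is that the fp32 path stays untouched, and the mixed-precision behaviour lives in one overridden method.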
I have tested …
I guess there is no need to wrap the optimizer; just use:

```python
from apex import amp

# Once, after the imports:
amp_handler = amp.init()

# In place of loss.backward():
with amp_handler.scale_loss(loss, self.optimizer) as scaled_loss:
    scaled_loss.backward()
```
Calling …
Also, I ran a similar experiment with amp on/off with a Transformer-based tagger. For some reason, I'm getting ~5% slower training times on RTX with mixed precision enabled (even with larger models), despite the Transformer being the architecture best suited to tensor cores. I tested this on a vanilla PyTorch Transformer tagger as well, and I'm seeing something similar. There's got to be some bottleneck I'm missing, as fairseq claims a nearly 5x speedup over plain …
@Hyperparticle fp16 gives about 2x to 2.5x the speed of fp32 for very simple matrix multiplications and CNNs on a V100 in PyTorch. I guess some specialized CUDA code is required to get more of a speedup (5x sounds too surprising to me). BTW, the tensor cores on RTX gaming cards are intentionally limited compared with those on the Titan or V100 (otherwise people would not buy a Titan).
Wondering if there's a PR for this. The solution given by @matthew-z using the new …
There was an unfinished PR #2467...
Any update?
Well, you can use this trainer to make it work for now.
Is there a PR for this? This would be a good addition, along with the recent changes to distributed training and gradient accumulation and @DeNeutoy's data loader changes.
I was able to get Apex working on the commit tagged as the v1.0 prerelease. Starting from …
This appears to work. My model trains ~20% faster, and its performance is unchanged. It also uses ~17% less memory. The majority of my model's parameters come from a pre-trained transformer (RoBERTa), and these are approximately the same improvements I got with Apex in a different project using a similar model. I will try to post a more rigorous benchmark here in the next couple of days. For anyone who wants to try the trainer out, I have a gist here. FYI, I tried subclassing …
Thanks @JohnGiorgi! Are you willing to make a PR for this? It seems pretty straightforward. I would just have the default value for …
@matt-gardner Yes, I can do that! Two questions: …
@JohnGiorgi, I was thinking just a small PR introducing the changes you made to our existing trainer. It looks like it's something like 10 lines of code difference (though I may have missed some changes). If we can have a default that doesn't use half precision, and it doesn't add much complexity, it seems like a no-brainer to put in those 10 lines of code.
If I'm missing some reason that would make it so that we really have to have a separate trainer entirely for this, please let me know what I'm missing.
@matt-gardner No, you are correct; I think it makes sense to just add it to the existing … Okay, I will take a crack at it and open a PR!
Fixed by #3866
It seems that currently there is no fp16 support in AllenNLP, but I think it is quite straightforward to do in a clean way with Apex.
Advantages of FP16:
Challenge:
Some functions in AllenNLP are not FP16-stable, such as `masked_softmax`.
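To illustrate how a hard-coded constant tuned for fp32 can break a masked softmax under fp16: the largest finite half-precision value is 65504, so a very large masking constant overflows to infinity when cast. The pure-Python sketch below simulates the cast with the standard library's half-precision `struct` format (`'e'`). The fill value `-1e32` is an assumed example of such an fp32-era constant, and `to_fp16` and `masked_softmax_fp16` are hypothetical helpers, not AllenNLP's `masked_softmax`.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value via struct's half-precision 'e' format.

    struct raises OverflowError for out-of-range values, whereas casting a
    tensor to float16 yields inf, so map the error to a signed infinity.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def masked_softmax_fp16(logits, mask, fill_value):
    """Toy masked softmax: masked slots get `fill_value`, and every value is
    rounded to fp16 before a max-subtracted softmax."""
    filled = [to_fp16(x if keep else fill_value) for x, keep in zip(logits, mask)]
    peak = max(filled)
    exps = [math.exp(x - peak) for x in filled]
    total = sum(exps)
    return [e / total for e in exps]
```

With `fill_value=-1e32`, a fully masked row becomes `[-inf, -inf]`, and `-inf - (-inf)` is NaN, so the whole row comes out as NaN. A fill value inside the fp16 range (for example `-FP16_MAX`) yields a uniform distribution instead. Partially masked rows are fine either way, since `exp(-inf)` is simply 0.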