
Mixed Precision Training #2149

Closed

matthew-z opened this issue Dec 6, 2018 · 22 comments

Comments

@matthew-z

matthew-z commented Dec 6, 2018

It seems there is currently no fp16 support in AllenNLP, but I think it would be quite straightforward to add cleanly with Apex.

Advantages of FP16:

  1. Larger batch sizes, especially useful for large embedding models like ELMo and BERT
  2. Faster training: the latest GPUs (P100, V100, RTX) can achieve a 60-80% speedup
  3. Many experiments have shown that FP16 does not hurt accuracy for CV models.

Challenge:
Some functions in AllenNLP are not FP16 stable, such as masked_softmax

@matt-gardner
Contributor

Thanks! For now, this is pretty low on our priority list, but contributions would be welcome here.

@matt-gardner matt-gardner added the P3 label Dec 6, 2018
@matt-peters
Contributor

Would love to see fp16 support in allennlp!

@Hyperparticle
Contributor

Hyperparticle commented Jan 6, 2019

I played around with apex/amp and found that integrating it into allennlp is straightforward, with a measurable performance boost on a large CRF tagger model (RTX GPU).

After cloning and installing apex, you only need to modify allennlp/training/trainer.py (or create your own trainer).

Near the top after the imports, import amp and initialize it.

from apex import amp
amp_handle = amp.init()

Replace the self.optimizer = optimizer assignment in the __init__() function, wrapping the optimizer with amp.

self.optimizer = amp_handle.wrap_optimizer(optimizer)

Finally, replace loss.backward() in _train_epoch() with the scaled loss.

with self.optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()

I tested this with a larger version of the crf_tagger.json tutorial config on CoNLL 2003 using CUDA 10.0 and PyTorch 1.0.0. With such a small model, training was actually slower with fp16 enabled. However, once I boosted the embedding/RNN size to something like 256/4096, I saw ~17% faster training per epoch on a Titan RTX (over 10 epochs). That's not close to the 2x I hoped for, but there may well be further optimizations that could improve it. I also noticed a small regularization effect on this model: fp32 training overfits early on, whereas fp16 does not.

During training, I see log messages like "Gradient overflow, skipping update" being emitted from apex. I'm not sure how to suppress them without interfering with allennlp's logging.

As @matthew-z mentioned, some allennlp functions are not stable, and I've not investigated these.

I'm not sure what additional work is necessary to submit a pull request. Thoughts?

@matt-gardner
Contributor

Awesome, thanks for trying this out! @joelgrus is currently thinking about how to make the Trainer easier to modify, and this is another thing to keep in mind for that, I think. He can probably give better advice than me on what exactly to put into a PR.

For functions that aren't stable, this is because of hard-coded values that were made with fp32 in mind, right? Like 1e-13 in masked_softmax and 1e-45 in masked_log_softmax? Are there any other issues? I think before we can integrate anything into the library, we'd need to make sure that you won't silently get incorrect behavior when using these functions. Crashing is ok - you tried to use something we don't officially support, and it crashed, you can write your own functions to work around it. But doing the wrong thing without warning is no good.
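For concreteness, both of those constants are below fp16's smallest positive value (roughly 6e-8), so they silently round to exactly zero in half precision and no longer act as a nonzero floor. A quick check in plain PyTorch (nothing AllenNLP-specific assumed):

import torch

# Both epsilons are far smaller than fp16's tiniest positive value (~6e-8),
# so casting them to half precision yields exactly 0.0.
for eps in (1e-13, 1e-45):
    print(eps, "->", torch.tensor(eps, dtype=torch.float16).item())
# 1e-13 -> 0.0
# 1e-45 -> 0.0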

@joelgrus
Contributor

joelgrus commented Jan 6, 2019

I'm still not 100% sure what I'm going to do about the trainer. In the meantime, if you wanted to do a PR, I would probably do the following (once you've done the correctness checking that mattg pointed out):

(1) give the current trainer a private method

def _backpropagate(self, loss: torch.Tensor) -> None:
    """
    This is pulled out as a method because subclasses might want to do something
    special (e.g. mixed-precision training).
    """
    loss.backward()

(2) replace the call to loss.backward() in _train_epoch with a call to that method

(3) create a subclass ApexTrainer(Trainer) that just overrides the constructor (so it can wrap the optimizer) and also overrides the new _backpropagate function.

Like I said, I'm not quite sure yet how I'm going to clean up the Trainer code, but I'm pretty sure I could work with this.
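For illustration, here is a rough sketch of what (3) might look like, reusing the amp_handle calls from @Hyperparticle's comment above; ApexTrainer and _backpropagate are just the names proposed in this thread, not anything that exists in the library yet.

import torch
from apex import amp

from allennlp.training.trainer import Trainer

# Created once at module level, as in the snippet further up the thread.
amp_handle = amp.init()


class ApexTrainer(Trainer):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Wrap the optimizer so apex can manage dynamic loss scaling.
        self.optimizer = amp_handle.wrap_optimizer(self.optimizer)

    def _backpropagate(self, loss: torch.Tensor) -> None:
        # Scale the loss before backprop so small fp16 gradients don't underflow.
        with self.optimizer.scale_loss(loss) as scaled_loss:
            scaled_loss.backward()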

@matthew-z
Author

matthew-z commented Jan 8, 2019

I have tested masked_softmax and masked_log_softmax. Their current implementations will not work in pure fp16 mode, but in apex.amp mode they are converted to fp32 automatically, so we don't need to worry about them.

@matthew-z
Author

@Hyperparticle

I guess there is no need to wrap the optimizer; just:

# near the top of trainer.py, after the other imports
from apex import amp
amp_handler = amp.init()

# then, in _train_epoch(), in place of loss.backward()
with amp_handler.scale_loss(loss, self.optimizer) as scaled_loss:
    scaled_loss.backward()

@Hyperparticle
Contributor

Hyperparticle commented Jan 8, 2019

@matthew-z

Calling amp_handler.scale_loss() is an alternative to wrapping the optimizer, but this assumes there is only one optimizer in the code. This assumption holds currently, but in the future, you may want multiple optimizers or backward passes (in which case, you must wrap the optimizer). Just something to consider.

@Hyperparticle
Contributor

Hyperparticle commented Jan 8, 2019

Also, I ran a similar experiment with amp on/off with a Transformer-based tagger. For some reason, I'm getting ~5% slower training times on the RTX with mixed precision enabled (even with larger models), despite Transformers being particularly well suited to tensor cores. I tested this on a vanilla PyTorch Transformer tagger as well and saw something similar. There has to be some bottleneck I'm missing, as fairseq claims nearly a 5x speedup over plain fp32.

@matthew-z
Author

@Hyperparticle
In my tests, apex achieves about a 1.8x speedup for a BERT-large model, but nearly no speedup for RNN models.

In PyTorch on a V100, fp16 gives roughly 2x-2.5x the speed of fp32 for simple matrix multiplications and CNNs. I suspect some specialized CUDA code is required to get more speedup than that (5x sounds too surprising to me).

BTW, the tensor-core throughput of consumer RTX cards is intentionally capped well below that of the Titan or V100 (otherwise people would not buy the Titan).

@scarecrow1123
Contributor

Wondering if there's a PR on this. The solution given by @matthew-z using the new amp API should work seamlessly with the current setup, I reckon. As for multiple optimizers and the use of scale_loss: the current Trainer class only supports a single optimizer anyway, so the suggested solution should work well. It does work in my local setup, and I would like to know if there are any other PRs for this. Thanks!

@matthew-z
Author

There was an unfinished PR #2467...

@jiacheng-xu

any update?

@scarecrow1123
Contributor

any update?

Well, you can use this trainer to make it work for now.

@scarecrow1123
Contributor

Is there a PR on this? This would be a good addition alongside the recent changes to distributed training, gradient accumulation, and @DeNeutoy's data loader changes.

@JohnGiorgi
Contributor

JohnGiorgi commented Feb 27, 2020

I was able to get Apex working on the commit tagged as v1.0 prerelease.

Starting from Trainer, I

  1. Added an argument called opt_level.
  2. Wrapped self.model and self.optimizer in the constructor with amp.initialize.
  3. Made a subtle modification to rescale_gradients to accommodate gradient clipping/normalization with amp (see here).
  4. Replaced loss.backward() with the amp.scale_loss context manager.

This appears to work. My model trains ~20% faster with no change in model performance, and it uses ~17% less memory. The majority of my model's parameters come from a pretrained transformer (RoBERTa), and these are approximately the same improvements I got with Apex in a different project using a similar model. I will try to post a more rigorous benchmark here in the next couple of days.

For anyone who wants to try the trainer out, I have a gist here.

FYI, I tried subclassing Trainer rather than copying it outright, but kept getting an error that train_data wasn't supplied. Couldn't figure this out.
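For anyone skimming, here is a minimal standalone sketch of those four changes applied to a toy model rather than the real Trainer; it assumes apex is installed and a CUDA GPU is available, and uses "O1" as the opt_level. The AllenNLP-specific wiring is in the gist above.

import torch
from apex import amp

model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# (1) + (2): take an opt_level and wrap the model and optimizer with amp.initialize.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 16).cuda()
y = torch.randint(0, 2, (8,)).cuda()
loss = torch.nn.functional.cross_entropy(model(x), y)

# (4): replace loss.backward() with the amp.scale_loss context manager.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# (3): clip the fp32 master gradients that amp maintains, not the fp16 model gradients.
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
optimizer.step()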

@matt-gardner
Contributor

Thanks @JohnGiorgi! Are you willing to make a PR for this? It seems pretty straightforward. I would just have the default value for opt_level be None, which means don't use amp (unless the default is already to use full precision).

@JohnGiorgi
Contributor

@matt-gardner Yes I can do that!

Two questions:

  1. Does it make sense to have an entirely new trainer, as I have done? Or should I subclass Trainer (and try to figure out why that throws errors)?
  2. What about a default value of "O0" for opt_level? This is a no-op for FP32 models (see here).

@matt-gardner
Contributor

@JohnGiorgi, I was thinking just a small PR introducing the changes you made to our existing trainer. It looks like it's something like 10 lines of code difference (though I may have missed some changes). If we can have a default that doesn't use half precision, and it doesn't add much complexity, it seems like a no-brainer to put in those 10 lines of code.

@matt-gardner
Contributor

If I'm missing some reason that would make it so that we really have to have a separate trainer entirely for this, please let me know what I'm missing.

@JohnGiorgi
Contributor

@matt-gardner No you are correct, I think it makes sense to just add it to the existing Trainer.

Okay I will take a crack at it and open a PR!

@DeNeutoy
Contributor

DeNeutoy commented Mar 3, 2020

Fixed by #3866

@DeNeutoy DeNeutoy closed this as completed Mar 3, 2020