Conversation
Could you please add it to the Dockerfile as well?
@miguelvr Done - hadn't noticed that file was there :)
Thanks a lot for the PR!
I've left a few comments and questions to get started, as I'm not familiar with APEX.
tools/train_net.py (Outdated)
-        model, device_ids=[local_rank], output_device=local_rank,
-        # this should be removed if we update BatchNorm stats
-        broadcast_buffers=False,
+    model = DDP(
Is there a difference now between DistributedDataParallel from PyTorch and the one from apex? What about the non-legacy DistributedDataParallel from c10d, does it have similar performance? Or does the apex one handle fp16 differently?
apex.DistributedDataParallel has similar perf to the c10d implementation, and it's what we've been running for the last few months -- I'm not married to the change, but I might suggest using it until the c10d implementation works for Mask-RCNN (if it doesn't already in master).
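For illustration, a minimal sketch of the two wrappers being compared, assuming apex is installed and a CUDA device is available; the single-process group setup and the delay_allreduce flag are assumptions for the sketch, not necessarily what this PR passes:

```python
import os
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

# single-process group so the wrapper can be constructed (illustration only)
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

model = torch.nn.Linear(8, 8).cuda()

# What this PR replaces (the PyTorch/c10d wrapper):
# model = torch.nn.parallel.DistributedDataParallel(
#     model, device_ids=[local_rank], output_device=local_rank,
#     broadcast_buffers=False,
# )

# apex wrapper: assumes one process per GPU, no device ids needed.
model = DDP(model, delay_allreduce=True)
```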
Do we actually need it for mixed precision to work?
We shouldn't, but I haven't tested it.
@@ -17,6 +17,13 @@ def __init__(self, n):
         self.register_buffer("running_var", torch.ones(n))

+    def forward(self, x):
+        # Cast all fixed parameters to half() if necessary
+        if x.dtype == torch.half:
+            self.weight = self.weight.half()
+            self.bias = self.bias.half()
+            self.running_mean = self.running_mean.half()
+            self.running_var = self.running_var.half()
So it seems that we don't explicitly cast the model to fp16 during initialization, is that right? This seems a bit counter-intuitive to me: what happens if we just cast everything in model to .half()?
Because not everything can be half, as not all ops support half. One could write a function to cast all ops that do support half -- cast_some_to_half(model), maybe -- but I decided to special-case this one. I'm open to other approaches :)
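A minimal sketch of what such a helper could look like; the name comes from the comment above, and the choice of which module types are safe to cast is an assumption for illustration only:

```python
import torch

def cast_some_to_half(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper (named in the discussion above, not in this PR):
    # cast compute-heavy layers to fp16 in place, while leaving modules
    # with numerically sensitive statistics (normalization layers) fp32.
    for module in model.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            module.half()
    return model
```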
What are the ops that do not support half (apart from the custom ones that are in this repo)? I thought that all ops in PyTorch supported fp16 on CUDA (with potentially bad accuracy).
In using "support" I chose my words badly. It's not necessarily support, more "can be used with a reasonable expectation of not losing accuracy". apex.amp takes a conservative approach by not moving ops to fp16 when we're not sure of their accuracy (the lists of what is / isn't moved to fp16 are in the files here).
We could try casting the entire model to half and see what happens -- there's enough code, in the RPN especially, whose behavior in lower precision I'm unsure about that I decided to be conservative beyond what apex.amp does. Unfortunately, that means that until PyTorch can grok y_16 = a_32 * x_16 + b_32 (where the subscripts denote precision), we have to do something manual here.
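A small illustration of that limitation, assuming a CUDA device; in the PyTorch of that era, mixing fp32 parameters with fp16 activations in one op raised a dtype error rather than promoting:

```python
import torch

x16 = torch.randn(8, device="cuda", dtype=torch.half)  # fp16 activation
a32 = torch.randn(8, device="cuda")                     # fp32 "parameter"
b32 = torch.randn(8, device="cuda")

# y16 = a32 * x16 + b32          # raised a dtype-mismatch error at the time
y16 = a32.half() * x16 + b32.half()  # manual cast, as FrozenBatchNorm2d does
```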
@@ -116,6 +116,6 @@ def forward(self, x, boxes):
         for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
             idx_in_level = torch.nonzero(levels == level).squeeze(1)
             rois_per_level = rois[idx_in_level]
-            result[idx_in_level] = pooler(per_level_feature, rois_per_level)
+            result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)
Doesn't amp.float_function wrap the values back to fp16 after they are computed?
I'd need to run again to work out exactly what case was happening here, but the type change was not happening correctly, and I had to cast manually here to prevent errors.
Quick question: is this casting still relevant?
Yes: result[idx_in_level] is expected to be in fp16 (as it's the same precision as the input), but the pooler returns fp32 (explicitly, as the pooler code hasn't had fp16 support added). To get around this, the result from the pooler needs to be cast.
But I thought that amp.float_function would:
1 - cast to float
2 - compute
3 - cast back to fp16
Or is my understanding of it wrong?
Ok.
I think the solution I'd potentially go with myself (while fp16 support is not present in the core pooler functions) is to just cast to float on the C++ side and cast back if the type is fp16.
But I suppose this is not really a hard requirement here (it would just make things cleaner).
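The same idea can be sketched on the Python side while the C++ kernels remain fp32-only; pool_fp32_safe and the pooler signature here are illustrative, not the repo's API:

```python
import torch

def pool_fp32_safe(pooler, feature, rois):
    # Hypothetical wrapper (not the repo's API): run an fp32-only pooling
    # op on possibly-fp16 inputs by casting up, computing, and casting the
    # result back to the caller's dtype -- the Python-side analogue of the
    # C++-side cast suggested above.
    orig_dtype = feature.dtype
    out = pooler(feature.float(), rois.float())
    return out.to(orig_dtype)
```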
Would you like a comment added explaining the current need for the cast? (Along with a TODO for full fp16 support in pooling, if you so desire.)
I just find it very unintuitive that you had to add this casting only here, and not after NMS as well. :-/
Because this is a strange case -- you're in fp16-land, you allocate your output as fp16, then run something that has to cast up to fp32. There's no automatic cast back (you're calling a plain function, so there's no module boundary to trigger a cast), so it has to be done manually. If lines 111-115 didn't exist and you inferred the type from the return type of the pooling call, this explicit cast wouldn't have to be there.
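A self-contained toy version of that situation, with a stand-in for the pooler; the shapes and names are illustrative only:

```python
import torch

def pool_fp32(feature):
    # stand-in for an op wrapped with amp.float_function: it computes in
    # fp32 and returns fp32, with no automatic cast back at the call site
    return feature.float().mean(dim=(2, 3))

x = torch.randn(4, 8, 16, 16, dtype=torch.half)
result = torch.zeros(4, 8, dtype=x.dtype)  # output allocated as fp16
idx = torch.tensor([0, 2])

# result[idx] = pool_fp32(x[idx])            # error: index-put needs matching dtypes
result[idx] = pool_fp32(x[idx]).to(x.dtype)  # the explicit cast from the diff above
```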
I understand it's a little weird, but that's where the code is right now -- if you want an explicit (non-AMP) fp16 version, that can also be done, but it'll be more invasive and can't happen before the new year, after I get back.
Any updates on this?
I'm rebasing right now (along with moving back to the PyTorch default DistributedDataParallel)
This looks good to me, thanks a lot @slayton58!
@wat3rBro will be running a few training jobs to double-check that accuracy is the same, and then I'll get this PR merged.
@slayton58 Hi, I'm collecting training stats and found that training terminates silently during the first iteration if fp16 is enabled for a ResNeXt model (both detection and mask) on V100. I checked the verbose output from APEX and couldn't find any difference from a successful run of the same model on P100. The exit happens when running:
Is this a known issue?
@wat3rBro I've never run with ResNeXt (it didn't exist in the version of the code I originally developed this against). There should be no difference in behavior between ResNet and ResNeXt. Does this happen on the first iteration? Also, do other tests pass on P100 (I develop on V100)?
@slayton58 It happens during the first iteration. Could you verify whether it works for e2e_mask_rcnn_X_101_32x8d_FPN_1x.yaml? On P100 I have R-50-C4/FPN running successfully without accuracy loss.
Adds fp16 support via apex.amp; also switches communication to apex.DistributedDataParallel
Added support to tools/test_net.py; SOLVER.MIXED_PRECISION -> DTYPE \in {float32, float16}; apex.amp not installed now raises ImportError
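A sketch of what the config switch from that commit might look like with the repo's yacs-style defaults; the top-level placement and default value are assumptions, taken only from the commit message above:

```python
from yacs.config import CfgNode as CN

_C = CN()
# replaces the earlier SOLVER.MIXED_PRECISION boolean; setting "float16"
# enables apex.amp mixed-precision training
_C.DTYPE = "float32"  # one of {"float32", "float16"}
```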
@wat3rBro I was having an issue with my branch due to some of the pre-trained model URLs having changed (this is the relevant commit) -- I rebased (and pushed) against latest master and can start training on a single V100 now.
@slayton58 Does it run properly on your machine, at least for the first few iterations?
Yes. Are you still having issues?
One thing to keep in mind: @wat3rBro, what version of cuDNN are you using that gives the problem?
@fmassa I'm still on cuDNN 7.1.2 because of our infra.
@wat3rBro Any progress on your side? If we want to try to track down more details on exactly where the failure is happening on your end, you could try running with
@slayton58 That env variable doesn't seem to give any error report either. Anyway, we'll have 7.4 available in several days; hopefully this issue just goes away.
@slayton58 It does work with 7.4.2.
@wat3rBro Could you please run with the following and give me the resulting cudnn.log file -- I can use it to file an internal bug (or email me at slayton (at) nvidia (dot) com if it's too big).
@wat3rBro Any progress?
Hey @slayton58, I've sent you the log via email, sorry if I didn't notify you.
@wat3rBro Hmm, I haven't seen anything come through my email - what address did you send it from?
Maybe it was blocked; I just sent it again from my personal email.
Curious, what's the current state of this PR? Some of the X-152 models I'm training are huge and slow, and I've seen notable improvements in the past from training with NVIDIA's half-precision code and PyTorch.
Hi @slayton58
I've run some tests on 4 x 2080Ti and didn't see any improvement in training time. In my tests, I set
@slayton58 What could be going wrong in these tests? Could you provide some benchmarks of the improvements you've achieved?
@zimenglan-sysu-512 I'm still waiting on a merge. The PR should still be up-to-date and compatible with recent refactors.
@LeviViana Speedup is dependent on network and batch size per GPU --
Thanks for your quick answer. Indeed, my previous tests used a 1 img/GPU setup. I re-ran some other tests on the
Thanks @slayton58
Sorry for the delay in merging, and thanks @slayton58!
Hi @slayton58
@zimenglan-sysu-512 No, labels will stay as
I don't know if it was taken into account, but here's what I installed:
error:
* Initial multi-precision training: adds fp16 support via apex.amp; also switches communication to apex.DistributedDataParallel
* Add Apex install to dockerfile
* Fixes from @fmassa review: added support to tools/test_net.py; SOLVER.MIXED_PRECISION -> DTYPE \in {float32, float16}; apex.amp not installed now raises ImportError
* Remove extraneous apex DDP import
* Move to new amp API
This PR adds initial mixed-precision training support via apex.amp.
Mixed-precision is controlled with the SOLVER.MIXED_PRECISION config argument.
Along with apex.amp support, I've moved DistributedDataParallel to apex.DistributedDataParallel, as this is what we've been using to good effect over the last few months.
Please note that this does add the apex package as a requirement.
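For readers landing here, a minimal sketch of apex.amp-style mixed precision in a training step, under the newer amp API the PR later moved to; the model, optimizer, and data are placeholders, not the repo's trainer:

```python
import torch
from apex import amp  # raises ImportError if apex isn't installed

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# "O1" patches whitelisted ops to run in fp16 while keeping the rest fp32
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# toy stand-in for a real data loader
data_loader = [(torch.randn(32, 128).cuda(),
                torch.randint(0, 10, (32,)).cuda())] * 10

for inputs, targets in data_loader:
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    # scale the loss so fp16 gradients don't underflow
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```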