Initial mixed-precision training #196
@@ -116,6 +116,6 @@ def forward(self, x, boxes):
        for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]
-            result[idx_in_level] = pooler(per_level_feature, rois_per_level)
+            result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)

        return result

Review comments on the added .to(dtype) cast:

doesn't the […]

I'd need to run again to work out exactly what case was happening here, but the type change was not happening correctly and I had to cast manually here to prevent errors.

Quick question: is this casting still relevant?

Yes: […]

But I thought that […] Or is my understanding of it wrong?

Ok.

Would you like a comment added explaining the current need for the cast? (Along with a TODO for full fp16 support in pooling, if you so desire.)

I just find it very unintuitive that you had to add this casting only here, and not after NMS as well. :-/

Because this is a strange case -- you're in fp16-land, you allocate your output as fp16, then you run something that has to cast up to fp32. There's no automatic cast back (you're calling from inside a module, so there's no module boundary to trigger one), so it has to be done manually. If lines 111-115 didn't exist and the type were inferred from the return type of the pooling call, this explicit cast wouldn't have to be there.

I understand it's a little weird, but that's where the code is right now -- if you want an explicit (non-AMP) fp16 version, that can also be done, but it'll be more invasive and can't be done before the new year, after I get back.
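As a concrete illustration of the scenario described in that thread, here is a minimal self-contained sketch; the shapes and the fake_pooler helper are made up for illustration and are not code from this repo. The output buffer lives in fp16, the pooling call returns fp32, and the indexed assignment is what forces the explicit cast back.

```python
import torch

# "fake_pooler" is a stand-in for an op that internally upcasts to fp32 and
# returns an fp32 tensor, like the pooling call discussed above.
def fake_pooler(rois):
    return torch.ones(rois.shape[0], 4, dtype=torch.float32)

dtype = torch.float16
result = torch.zeros(8, 4, dtype=dtype)   # output buffer allocated in fp16
idx_in_level = torch.tensor([0, 2, 5])
rois_per_level = torch.zeros(3, 5)

# Assigning fp32 values into an fp16 buffer via advanced indexing raises a
# dtype-mismatch error (at least on the PyTorch versions this PR targets), so
# the pooled values are cast back to `dtype` before the assignment -- the same
# pattern as the added .to(dtype) in the diff above.
result[idx_in_level] = fake_pooler(rois_per_level).to(dtype)
```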
So it seems that we don't explicitly cast the model to fp16 during initialization, is that right? This seems a bit counter-intuitive to me; what happens if we just cast everything in model to .half()?
Because not everything can be half, as not all ops support half. One could write a function that casts only the ops that do support half -- cast_some_to_half(model), maybe -- but I decided to special-case this one. I'm open to other approaches :)
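For what it's worth, a rough sketch of what such a cast_some_to_half helper could look like; the name comes from the comment above, and the whitelist of "fp16-safe" module types is purely illustrative rather than something proposed in this PR.

```python
import torch.nn as nn

# Hypothetical whitelist: module types we are willing to run in fp16.
FP16_SAFE_TYPES = (nn.Conv2d, nn.Linear)

def cast_some_to_half(model: nn.Module) -> nn.Module:
    for module in model.modules():
        if isinstance(module, FP16_SAFE_TYPES):
            module.half()   # converts this module's parameters and buffers
    return model
```

Converting parameters alone is not sufficient, though: the activations feeding those layers still have to be cast at the fp16/fp32 boundaries (a conv with fp16 weights will reject an fp32 input), which is exactly the kind of manual bookkeeping the rest of this thread is about.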
What are the ops that do not support half (apart from the custom ones that are in this repo)? I thought that all ops in PyTorch supported fp16 on CUDA (with potentially bad accuracy).

In using "support" I chose my words badly. It's not necessarily support, more "can be used with a reasonable expectation of not losing accuracy". Apex.amp takes a conservative approach by not moving ops to fp16 when we're not sure of their accuracy (the lists of what is / isn't moved to fp16 are in the files here). We could try casting the entire model to half and see what happens -- there's enough code in the RPN especially that I'm just not sure how it'll behave in lower precision, so I decided to be conservative beyond what apex.amp does. Unfortunately, that means that until PyTorch can grok y_16 = a_32 * x_16 + b_32 (where the subscripts denote precision), we have to do something manual here.
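To make that last point concrete, here is a small sketch (not code from this PR) assuming a PyTorch version without automatic mixed-dtype promotion: with an fp16 activation and fp32 parameters, one side of the expression has to be cast explicitly.

```python
import torch

x_16 = torch.randn(4).half()                   # fp16 activation
a_32 = torch.randn(4, dtype=torch.float32)     # fp32 "master" parameters
b_32 = torch.randn(4, dtype=torch.float32)

# Without mixed-dtype promotion, y_16 = a_32 * x_16 + b_32 cannot be written
# directly; one manual option is to cast the fp32 parameters down to the
# activation dtype at the point of use ...
y_16 = a_32.to(x_16.dtype) * x_16 + b_32.to(x_16.dtype)

# ... or to do the arithmetic in fp32 and cast only the result back down.
y_16_alt = (a_32 * x_16.float() + b_32).to(x_16.dtype)
```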