
Initial mixed-precision training #196

Merged (5 commits) on Apr 19, 2019
Changes from all commits
6 changes: 6 additions & 0 deletions INSTALL.md
@@ -42,6 +42,12 @@ git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py build_ext install

# install apex
cd ~/github
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext

# install PyTorch Detection
cd ~/github
git clone https://github.com/facebookresearch/maskrcnn-benchmark.git
5 changes: 5 additions & 0 deletions docker/Dockerfile
@@ -45,6 +45,11 @@ RUN git clone https://github.com/cocodataset/cocoapi.git \
&& cd cocoapi/PythonAPI \
&& python setup.py build_ext install

# install apex
RUN git clone https://github.com/NVIDIA/apex.git \
&& cd apex \
&& python setup.py install --cuda_ext --cpp_ext

# install PyTorch Detection
RUN git clone https://github.com/facebookresearch/maskrcnn-benchmark.git \
&& cd maskrcnn-benchmark \
10 changes: 10 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
@@ -272,3 +272,13 @@
_C.OUTPUT_DIR = "."

_C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")

# ---------------------------------------------------------------------------- #
# Precision options
# ---------------------------------------------------------------------------- #

# Precision of input, allowable: (float32, float16)
_C.DTYPE = "float32"

# Enable verbosity in apex.amp
_C.AMP_VERBOSE = False
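
For illustration only: assuming the usual yacs command-line override that tools/train_net.py already supports, the new option can be flipped on without editing a config file (the config path below is just an example from the repo's configs directory):

python tools/train_net.py --config-file configs/e2e_mask_rcnn_R_50_FPN_1x.yaml DTYPE float16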
6 changes: 5 additions & 1 deletion maskrcnn_benchmark/engine/trainer.py
@@ -9,6 +9,7 @@
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.metric_logger import MetricLogger

from apex import amp

def reduce_loss_dict(loss_dict):
"""
@@ -73,7 +74,10 @@ def do_train(
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
-        losses.backward()
+        # Note: If mixed precision is not used, this ends up doing nothing
+        # Otherwise apply loss scaling for mixed-precision recipe
+        with amp.scale_loss(losses, optimizer) as scaled_losses:
+            scaled_losses.backward()
        optimizer.step()

        batch_time = time.time() - end
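
For readers new to apex.amp, a minimal self-contained sketch (not from this PR) of how the scale_loss context manager above fits together with amp.initialize; the toy model and optimizer are stand-ins for what build_detection_model / make_optimizer produce:

import torch
import torch.nn.functional as F
from apex import amp  # requires the apex install added to INSTALL.md above

# Toy stand-ins for the real model and optimizer.
model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# "O1" patches common ops to run in fp16 and maintains a loss scaler;
# "O0" leaves everything in fp32, making scale_loss effectively a no-op.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for _ in range(10):
    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.randint(0, 2, (4,), device="cuda")
    losses = F.cross_entropy(model(inputs), targets)

    optimizer.zero_grad()
    # Scale the loss before backward so fp16 gradients do not underflow.
    with amp.scale_loss(losses, optimizer) as scaled_losses:
        scaled_losses.backward()
    optimizer.step()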
7 changes: 7 additions & 0 deletions maskrcnn_benchmark/layers/batch_norm.py
@@ -17,6 +17,13 @@ def __init__(self, n):
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        # Cast all fixed parameters to half() if necessary
        if x.dtype == torch.float16:
            self.weight = self.weight.half()
            self.bias = self.bias.half()
            self.running_mean = self.running_mean.half()
            self.running_var = self.running_var.half()

        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        scale = scale.reshape(1, -1, 1, 1)

Review thread on the self.weight cast:

Contributor: So it seems that we don't explicitly cast the model to fp16 during initialization, is that right? This seems a bit counter-intuitive to me; what happens if we just cast everything in the model to .half()?

Contributor (Author): Because not everything can be half, as not all ops support half. One could write a function that casts only the ops that do support half, cast_some_to_half(model) maybe, but I decided to special-case this one -- I'm open to other approaches :) (see the sketch after this diff).

Contributor: What are the ops that do not support half (apart from the custom ones in this repo)? I thought that all ops in PyTorch supported fp16 on CUDA (with potentially bad accuracy).

Contributor (Author): In using "support" I chose my words badly. It's not so much support as "can be used with a reasonable expectation of not losing accuracy". Apex.amp takes a conservative approach by not moving ops to fp16 when we're not sure of their accuracy (the lists of what is / isn't moved to fp16 are in the files here). We could try casting the entire model to half and see what happens -- there's enough code in the RPN in particular whose behavior in lower precision I'm unsure of, so I decided to be conservative beyond what apex.amp does. Unfortunately, that means that until PyTorch can grok y_16 = a_32 * x_16 + b_32 (where the subscripts denote precision), we have to do something manual here.
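
To make the cast_some_to_half(model) idea from the thread concrete, here is a hypothetical helper (not part of this PR; the list of "safe" module types is an assumption for illustration) that walks the module tree and converts only layers generally considered fp16-safe, leaving precision-sensitive layers such as batch-norm variants in fp32:

import torch.nn as nn

# Hypothetical helper sketched from the discussion above; not in this PR.
FP16_SAFE_MODULES = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)

def cast_some_to_half(model):
    # Convert parameters/buffers only for modules assumed to be fp16-safe.
    for module in model.modules():
        if isinstance(module, FP16_SAFE_MODULES):
            module.half()
    return model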
6 changes: 5 additions & 1 deletion maskrcnn_benchmark/layers/nms.py
@@ -2,6 +2,10 @@
# from ._utils import _C
from maskrcnn_benchmark import _C

-nms = _C.nms
+from apex import amp
+
+# Only valid with fp32 inputs - give AMP the hint
+nms = amp.float_function(_C.nms)

# nms.__doc__ = """
# This function performs Non-maximum suppresion"""
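
As a rough mental model of what this hint does (an illustration only, not apex's actual implementation): amp.float_function marks the wrapped op as fp32-only, so AMP promotes any half-precision tensor arguments before the call, and the result comes back in whatever dtype the op produced:

import functools
import torch

def float_function_illustration(fn):
    # Illustrative stand-in for apex.amp.float_function: promote fp16 tensor
    # arguments to fp32 before calling fn. The output is returned as-is,
    # which is why poolers.py further down still casts with .to(dtype).
    def to_fp32(a):
        return a.float() if torch.is_tensor(a) and a.dtype == torch.float16 else a

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = tuple(to_fp32(a) for a in args)
        kwargs = {key: to_fp32(value) for key, value in kwargs.items()}
        return fn(*args, **kwargs)

    return wrapper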
3 changes: 2 additions & 1 deletion maskrcnn_benchmark/layers/roi_align.py
@@ -7,6 +7,7 @@

from maskrcnn_benchmark import _C

from apex import amp

class _ROIAlign(Function):
    @staticmethod
@@ -46,14 +47,14 @@ def backward(ctx, grad_output):

roi_align = _ROIAlign.apply


class ROIAlign(nn.Module):
    def __init__(self, output_size, spatial_scale, sampling_ratio):
        super(ROIAlign, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

    @amp.float_function
    def forward(self, input, rois):
        return roi_align(
            input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/layers/roi_pool.py
@@ -7,6 +7,7 @@

from maskrcnn_benchmark import _C

from apex import amp

class _ROIPool(Function):
    @staticmethod
@@ -52,6 +53,7 @@ def __init__(self, output_size, spatial_scale):
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    @amp.float_function
    def forward(self, input, rois):
        return roi_pool(input, rois, self.output_size, self.spatial_scale)

2 changes: 1 addition & 1 deletion maskrcnn_benchmark/modeling/poolers.py
@@ -116,6 +116,6 @@ def forward(self, x, boxes):
        for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]
-            result[idx_in_level] = pooler(per_level_feature, rois_per_level)
+            result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)
        return result

Review thread on the .to(dtype) cast:

Contributor: Doesn't amp.float_function cast the values back to fp16 after they are computed?

Contributor (Author): I'd need to run it again to work out exactly what case was happening here, but the type change was not happening correctly and I had to cast manually here to prevent errors.

Contributor: Quick question: is this casting still relevant?

Contributor (Author): Yes: result[idx_in_level] is expected to be in fp16 (as it's the same precision as the input), but the pooler returns fp32 (explicitly, as the code hasn't had fp16 support added). To get around this, the result from the pooler needs to be cast.

Contributor: But I thought that amp.float_function would:
1 - cast to float
2 - compute
3 - cast back to fp16
Or is my understanding of it wrong?

Contributor: Ok. I think the solution I'd potentially go with myself (while fp16 support is not present in the core pooler functions) is to cast to float on the C++ side and cast back if the type is fp16. But I suppose this is not really a hard requirement here (it would just make things cleaner).

Contributor (Author): Would you like a comment added explaining the current need for the cast? (Along with a TODO for full fp16 support in pooling, if you so desire.)

Contributor: I just find it very unintuitive why you had to add this casting only here, and not after NMS as well. :-/

Contributor (Author): Because this is a strange case -- you're in fp16-land, you allocate your output as fp16, then run something that has to cast up to fp32. There's no automatic cast back (you're calling from a module, so there's no module boundary to trigger a cast), so it has to be done manually. If lines 111-115 didn't exist and you inferred the type from the return type of the pooling call, this explicit cast wouldn't have to be there.

Contributor (Author): I understand it's a little weird, but that's where the code is right now -- if you want an explicit (non-AMP) fp16 version, that can also be done, but it will be more invasive and can't happen before the new year, after I get back.
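
To make the "infer the type from the return type of the pooling call" alternative mentioned above concrete, a hypothetical rewrite of this loop (only a sketch, not part of the PR; num_rois, num_channels and output_size are assumed to be the locals of the surrounding forward) could allocate result lazily from the pooler's own output dtype, so the explicit .to(dtype) cast would no longer be needed:

result = None
for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
    idx_in_level = torch.nonzero(levels == level).squeeze(1)
    rois_per_level = rois[idx_in_level]
    pooled = pooler(per_level_feature, rois_per_level)
    if result is None:
        # Take dtype/device from the first pooled output instead of the input.
        result = torch.zeros(
            (num_rois, num_channels, output_size, output_size),
            dtype=pooled.dtype,
            device=pooled.device,
        )
    result[idx_in_level] = pooled
return result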
5 changes: 5 additions & 0 deletions maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
@@ -111,11 +111,16 @@ def expand_masks(mask, padding):
    pad2 = 2 * padding
    scale = float(M + pad2) / M
    padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))

    padded_mask[:, :, padding:-padding, padding:-padding] = mask
    return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
    # Need to work on the CPU, where fp16 isn't supported - cast to float to avoid this
    mask = mask.float()
    box = box.float()

    padded_mask, scale = expand_masks(mask[None], padding=padding)
    mask = padded_mask[0, 0]
    box = expand_boxes(box[None], scale)[0]
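
For context on the cast above, a small self-contained illustration of the pattern (assuming a PyTorch build of this era, where many CPU kernels have no Half implementation): promote the fp16 mask to fp32 before doing any CPU-side resizing.

import torch
import torch.nn.functional as F

# Stand-in for a predicted mask that arrives in fp16; the shape is arbitrary.
mask = torch.rand(1, 1, 28, 28).half()
mask = mask.float()  # cast before CPU-only post-processing, as the diff does
resized = F.interpolate(mask, size=(56, 56), mode="bilinear", align_corners=False)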
10 changes: 10 additions & 0 deletions tools/test_net.py
@@ -17,6 +17,12 @@
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir

# Check if we can enable mixed-precision via apex.amp
try:
    from apex import amp
except ImportError:
    raise ImportError('Use APEX for mixed precision via apex.amp')


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
@@ -60,6 +66,10 @@ def main():
    model = build_detection_model(cfg)
    model.to(cfg.MODEL.DEVICE)

    # Initialize mixed-precision if necessary
    use_mixed_precision = cfg.DTYPE == 'float16'
    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    _ = checkpointer.load(cfg.MODEL.WEIGHT)
12 changes: 12 additions & 0 deletions tools/train_net.py
@@ -25,6 +25,13 @@
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir

# See if we can use apex.DistributedDataParallel instead of the torch default,
# and enable mixed-precision via apex.amp
try:
    from apex import amp
except ImportError:
    raise ImportError('Use APEX for mixed precision via apex.amp')


def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
@@ -34,6 +41,11 @@ def train(cfg, local_rank, distributed):
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
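
The import comment above also mentions apex's DistributedDataParallel as a possible replacement for the torch default. That swap is not part of this diff, but a hedged sketch of what it might look like (delay_allreduce is an apex-specific knob; the torch wrapping shown above remains the fallback) is:

# Hypothetical alternative to the torch DDP wrapping shown above; not in this PR.
from apex.parallel import DistributedDataParallel as ApexDDP

if distributed:
    # apex's DDP infers the device from the model's parameters;
    # delay_allreduce=True is the more conservative (less overlapped) mode.
    model = ApexDDP(model, delay_allreduce=True)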