diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py
index 90cf7fc88..17b107bec 100644
--- a/maskrcnn_benchmark/config/defaults.py
+++ b/maskrcnn_benchmark/config/defaults.py
@@ -254,9 +254,6 @@
 # see 2 images per batch
 _C.SOLVER.IMS_PER_BATCH = 16
 
-# Whether or not to use mixed-precision (via apex.amp)
-_C.SOLVER.MIXED_PRECISION = False
-
 # ---------------------------------------------------------------------------- #
 # Specific test options
 # ---------------------------------------------------------------------------- #
@@ -275,3 +272,13 @@
 _C.OUTPUT_DIR = "."
 
 _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")
+
+# ---------------------------------------------------------------------------- #
+# Precision options
+# ---------------------------------------------------------------------------- #
+
+# Precision of input, allowable: (float32, float16)
+_C.DTYPE = "float32"
+
+# Enable verbosity in apex.amp
+_C.AMP_VERBOSE = False
diff --git a/maskrcnn_benchmark/layers/batch_norm.py b/maskrcnn_benchmark/layers/batch_norm.py
index 712faedc1..3762e49e8 100644
--- a/maskrcnn_benchmark/layers/batch_norm.py
+++ b/maskrcnn_benchmark/layers/batch_norm.py
@@ -18,7 +18,7 @@ def __init__(self, n):
 
     def forward(self, x):
         # Cast all fixed parameters to half() if necessary
-        if x.type() == torch.half:
+        if x.dtype == torch.float16:
             self.weight = self.weight.half()
             self.bias = self.bias.half()
             self.running_mean = self.running_mean.half()
diff --git a/tools/test_net.py b/tools/test_net.py
index e1eb33e76..effa180ba 100644
--- a/tools/test_net.py
+++ b/tools/test_net.py
@@ -17,6 +17,12 @@
 from maskrcnn_benchmark.utils.logger import setup_logger
 from maskrcnn_benchmark.utils.miscellaneous import mkdir
 
+# Check if we can enable mixed-precision via apex.amp
+try:
+    from apex import amp
+except ImportError:
+    raise ImportError('Use APEX for mixed precision via apex.amp')
+
 
 def main():
     parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
@@ -60,6 +66,10 @@ def main():
     model = build_detection_model(cfg)
     model.to(cfg.MODEL.DEVICE)
 
+    # Initialize mixed-precision if necessary
+    use_mixed_precision = cfg.DTYPE == 'float16'
+    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)
+
     output_dir = cfg.OUTPUT_DIR
     checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
     _ = checkpointer.load(cfg.MODEL.WEIGHT)
diff --git a/tools/train_net.py b/tools/train_net.py
index 34c4ac4a8..b089a9cbb 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -31,7 +31,7 @@
     from apex.parallel import DistributedDataParallel as DDP
     from apex import amp
 except ImportError:
-    print('Use APEX for better performance via apex.amp and apex.DistributedDataParallel')
+    raise ImportError('Use APEX for better performance via apex.amp and apex.DistributedDataParallel')
 
 
 def train(cfg, local_rank, distributed):
@@ -42,9 +42,11 @@ def train(cfg, local_rank, distributed):
     optimizer = make_optimizer(cfg, model)
     scheduler = make_lr_scheduler(cfg, optimizer)
 
-    # Wrap the optimizer for fp16 training
-    use_mixed_precision = cfg.SOLVER.MIXED_PRECISION
-    amp_handle = amp.init(enabled=use_mixed_precision, verbose=False)
+    # Initialize mixed-precision training
+    use_mixed_precision = cfg.DTYPE == "float16"
+    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)
+
+    # wrap the optimizer for mixed precision
     optimizer = amp_handle.wrap_optimizer(optimizer)
 
     if distributed:
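
Note (not part of the diff): for readers unfamiliar with the legacy apex.amp "amp_handle" interface used above, the sketch below shows how the wrapped optimizer is typically driven inside a training step with that API (amp.init / wrap_optimizer / scale_loss). It is a minimal illustration only: `model`, `optimizer`, and `data_loader` are placeholders rather than the objects built by maskrcnn_benchmark, and the exact backward call in the repository's own training loop may differ.

    from apex import amp  # legacy amp_handle API, matching the calls in train_net.py

    amp_handle = amp.init(enabled=True, verbose=False)  # enable fp16 with dynamic loss scaling
    optimizer = amp_handle.wrap_optimizer(optimizer)    # same wrapping as in train() above

    for images, targets in data_loader:                 # placeholder loop over the dataset
        loss_dict = model(images, targets)              # detection model returns a dict of losses
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        # scale_loss multiplies the loss by the current loss scale so fp16 gradients
        # do not underflow; gradients are unscaled again before the weight update.
        with optimizer.scale_loss(losses) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

The config switch in defaults.py means this path is taken only when DTYPE is set to "float16"; with the default "float32" the amp handle is created with enabled=False and the loop behaves as ordinary full-precision training.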