This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Fixes from @fmassa review
Added support to tools/test_net.py
SOLVER.MIXED_PRECISION -> DTYPE \in {float32, float16}
apex.amp not installed now raises ImportError
slayton58 committed Nov 26, 2018
1 parent 275b290 commit a2ecbe7
Showing 4 changed files with 27 additions and 8 deletions.
13 changes: 10 additions & 3 deletions maskrcnn_benchmark/config/defaults.py
@@ -254,9 +254,6 @@
 # see 2 images per batch
 _C.SOLVER.IMS_PER_BATCH = 16
 
-# Whether or not to use mixed-precision (via apex.amp)
-_C.SOLVER.MIXED_PRECISION = False
-
 # ---------------------------------------------------------------------------- #
 # Specific test options
 # ---------------------------------------------------------------------------- #
@@ -275,3 +272,13 @@
 _C.OUTPUT_DIR = "."
 
 _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")
+
+# ---------------------------------------------------------------------------- #
+# Precision options
+# ---------------------------------------------------------------------------- #
+
+# Precision of input, allowable: (float32, float16)
+_C.DTYPE = "float32"
+
+# Enable verbosity in apex.amp
+_C.AMP_VERBOSE = False
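
As a usage sketch (not shown in the commit itself), the two new keys can be flipped through the usual yacs override path, assuming the repo's global cfg object and merge_from_list; the same comparison used by train_net.py and test_net.py below then derives the fp16 flag:

    # Minimal sketch, assuming the repo's global yacs config object.
    from maskrcnn_benchmark.config import cfg

    # Override the new precision options added in defaults.py above.
    cfg.merge_from_list(["DTYPE", "float16", "AMP_VERBOSE", "True"])

    # Same derivation this commit adds to train_net.py / test_net.py.
    use_mixed_precision = cfg.DTYPE == "float16"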
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/layers/batch_norm.py
@@ -18,7 +18,7 @@ def __init__(self, n):
 
     def forward(self, x):
         # Cast all fixed parameters to half() if necessary
-        if x.type() == torch.half:
+        if x.dtype == torch.float16:
             self.weight = self.weight.half()
             self.bias = self.bias.half()
             self.running_mean = self.running_mean.half()
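
The one-line change above matters because Tensor.type() with no arguments returns a string such as 'torch.HalfTensor' or 'torch.cuda.HalfTensor', so comparing it against torch.half (a torch.dtype) is always False and the frozen parameters were never cast under fp16. Comparing x.dtype against torch.float16 makes the cast actually fire. A standalone illustration of the difference:

    import torch

    x = torch.zeros(1, dtype=torch.float16)
    print(x.type())                  # 'torch.HalfTensor' -- a string, not a dtype
    print(x.type() == torch.half)    # False: a string never equals a torch.dtype
    print(x.dtype == torch.float16)  # True: the check this commit switches to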
10 changes: 10 additions & 0 deletions tools/test_net.py
@@ -17,6 +17,12 @@
 from maskrcnn_benchmark.utils.logger import setup_logger
 from maskrcnn_benchmark.utils.miscellaneous import mkdir
 
+# Check if we can enable mixed-precision via apex.amp
+try:
+    from apex import amp
+except ImportError:
+    raise ImportError('Use APEX for mixed precision via apex.amp')
+
 
 def main():
     parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
@@ -60,6 +66,10 @@ def main():
     model = build_detection_model(cfg)
     model.to(cfg.MODEL.DEVICE)
 
+    # Initialize mixed-precision if necessary
+    use_mixed_precision = cfg.DTYPE == 'float16'
+    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)
+
     output_dir = cfg.OUTPUT_DIR
     checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
     _ = checkpointer.load(cfg.MODEL.WEIGHT)
10 changes: 6 additions & 4 deletions tools/train_net.py
@@ -31,7 +31,7 @@
     from apex.parallel import DistributedDataParallel as DDP
     from apex import amp
 except ImportError:
-    print('Use APEX for better performance via apex.amp and apex.DistributedDataParallel')
+    raise ImportError('Use APEX for better performance via apex.amp and apex.DistributedDataParallel')
 
 
 def train(cfg, local_rank, distributed):
@@ -42,9 +42,11 @@ def train(cfg, local_rank, distributed):
     optimizer = make_optimizer(cfg, model)
     scheduler = make_lr_scheduler(cfg, optimizer)
 
-    # Wrap the optimizer for fp16 training
-    use_mixed_precision = cfg.SOLVER.MIXED_PRECISION
-    amp_handle = amp.init(enabled=use_mixed_precision, verbose=False)
+    # Initialize mixed-precision training
+    use_mixed_precision = cfg.DTYPE == "float16"
+    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)
+
+    # wrap the optimizer for mixed precision
     optimizer = amp_handle.wrap_optimizer(optimizer)
 
     if distributed:
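
Wrapping the optimizer is only half of the old apex.amp handle interface; the training loop also has to route the backward pass through the wrapped optimizer so the loss is actually scaled and the gradients unscaled. A rough sketch of that consuming side, assuming the pre-1.0 amp_handle/scale_loss API used in this diff (the real loop lives in the repo's trainer, not in this commit):

    # Sketch of a training step with the optimizer returned by
    # amp_handle.wrap_optimizer() (assumed old apex.amp handle API).
    for images, targets in data_loader:
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        # The context manager scales the loss for the fp16 backward pass
        # and unscales gradients before the optimizer step.
        with optimizer.scale_loss(losses) as scaled_losses:
            scaled_losses.backward()
        optimizer.step()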
