[Fix] Convert SyncBN to BN when training on DP (#772)
* [Fix] Convert SyncBN to BN when training on DP.

* Modify SyncBN2BN.

* Add SyncBN2BN unit test.

* Resolve review comments.

* Use MMCV's official revert_sync_batchnorm.

* Remove local syncbn2bn unit tests.

* Update mmcv version.

* Fix bugs in the gather_models tool.

* Modify warnings.

* Update the MMCV version in the Dockerfiles.

* Update mmcv version table.
clownrat6 authored Sep 15, 2021
1 parent 5a7996d commit cae715a
Showing 9 changed files with 26 additions and 30 deletions.
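The heart of the change is MMCV's `revert_sync_batchnorm` (shipped in mmcv-full 1.3.13, hence the version bumps below), which recursively replaces every `SyncBatchNorm` layer with an equivalent plain `BatchNorm`, carrying over weights and running statistics. A minimal sketch of the behavior on a toy model (not code from this repository):

```python
# Toy sketch: SyncBN needs a distributed process group at forward time;
# revert_sync_batchnorm swaps it for ordinary BN so the model also runs
# on a single GPU or under DataParallel.
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm

model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.SyncBatchNorm(8),
    nn.ReLU(),
)
model = revert_sync_batchnorm(model)
assert not any(isinstance(m, nn.SyncBatchNorm) for m in model.modules())
```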
2 changes: 1 addition & 1 deletion .dev/gather_models.py
@@ -75,7 +75,7 @@ def get_final_results(log_json_path, iter_num):
 def parse_args():
     parser = argparse.ArgumentParser(description='Gather benchmarked models')
     parser.add_argument(
-        '-c', '--config-name', type=str, help='Process the selected config.')
+        '-f', '--config-name', type=str, help='Process the selected config.')
     parser.add_argument(
         '-w',
         '--work-dir',
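After the rename the tool is driven with `-f`; a hypothetical invocation (config name and work dir are placeholders, not taken from this commit):

```shell
python .dev/gather_models.py -f fcn_r50-d8_512x1024_40k_cityscapes -w work_dirs
```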
2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -1,7 +1,7 @@
 ARG PYTORCH="1.6.0"
 ARG CUDA="10.1"
 ARG CUDNN="7"
-ARG MMCV="1.3.12"
+ARG MMCV="1.3.13"

 FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

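Because the version is a build `ARG`, it can also be overridden at build time without editing the file:

```shell
docker build -t mmsegmentation --build-arg MMCV=1.3.13 -f docker/Dockerfile .
```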
2 changes: 1 addition & 1 deletion docker/serve/Dockerfile
@@ -3,7 +3,7 @@ ARG CUDA="10.1"
 ARG CUDNN="7"
 FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

-ARG MMCV="1.3.12"
+ARG MMCV="1.3.13"
 ARG MMSEG="0.17.0"

 ENV PYTHONUNBUFFERED TRUE
2 changes: 1 addition & 1 deletion docs/get_started.md
@@ -11,7 +11,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the

 | MMSegmentation version | MMCV version |
 |:-------------------:|:-------------------:|
-| master | mmcv-full>=1.3.7, <1.4.0 |
+| master | mmcv-full>=1.3.13, <1.4.0 |
 | 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
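For reference, a compatible mmcv-full can be installed from OpenMMLab's prebuilt wheels; the index URL pattern follows MMCV's installation docs (substitute your own CUDA and PyTorch versions):

```shell
pip install mmcv-full==1.3.13 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
```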
8 changes: 8 additions & 0 deletions docs/train.md
@@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c

 ### Train with a single GPU

+official support:
+
+```shell
+./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
+```
+
+experimental support (Convert SyncBN to BN):
+
 ```shell
 python tools/train.py ${CONFIG_FILE} [optional arguments]
 ```
2 changes: 1 addition & 1 deletion docs_zh-CN/get_started.md
@@ -11,7 +11,7 @@

 | MMSegmentation version | MMCV version |
 |:-------------------:|:-------------------:|
-| master | mmcv-full>=1.3.7, <1.4.0 |
+| master | mmcv-full>=1.3.13, <1.4.0 |
 | 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
 | 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
2 changes: 1 addition & 1 deletion mmseg/__init__.py
@@ -6,7 +6,7 @@

 from .version import __version__, version_info

-MMCV_MIN = '1.3.7'
+MMCV_MIN = '1.3.13'
 MMCV_MAX = '1.4.0'


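A sketch of the guard these constants feed. The actual assertion sits just below this hunk, so the comparison and message here are assumptions; `digit_version` is MMCV's real version-parsing helper:

```python
# Assumed form of the version guard; the real check in mmseg/__init__.py
# may differ (e.g. in upper-bound inclusivity).
import mmcv
from mmcv.utils import digit_version

MMCV_MIN = '1.3.13'
MMCV_MAX = '1.4.0'

assert digit_version(MMCV_MIN) <= digit_version(mmcv.__version__) \
       <= digit_version(MMCV_MAX), \
    f'mmcv=={mmcv.__version__} is incompatible; install ' \
    f'mmcv-full>={MMCV_MIN}, <{MMCV_MAX}.'
```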
26 changes: 2 additions & 24 deletions tests/test_models/test_forward.py
@@ -8,7 +8,7 @@
 import pytest
 import torch
 import torch.nn as nn
-from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
+from mmcv.cnn.utils import revert_sync_batchnorm


def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
@@ -189,28 +189,6 @@ def _check_input_dim(self, inputs):
     pass


-def _convert_batchnorm(module):
-    module_output = module
-    if isinstance(module, SyncBatchNorm):
-        # to be consistent with SyncBN, we hack dim check function in BN
-        module_output = _BatchNorm(module.num_features, module.eps,
-                                   module.momentum, module.affine,
-                                   module.track_running_stats)
-        if module.affine:
-            module_output.weight.data = module.weight.data.clone().detach()
-            module_output.bias.data = module.bias.data.clone().detach()
-            # keep requires_grad unchanged
-            module_output.weight.requires_grad = module.weight.requires_grad
-            module_output.bias.requires_grad = module.bias.requires_grad
-        module_output.running_mean = module.running_mean
-        module_output.running_var = module.running_var
-        module_output.num_batches_tracked = module.num_batches_tracked
-    for name, child in module.named_children():
-        module_output.add_module(name, _convert_batchnorm(child))
-    del module
-    return module_output
-
-
 @patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
        _check_input_dim)
 @patch('torch.distributed.get_world_size', get_world_size)
@@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file):
         imgs = imgs.cuda()
         gt_semantic_seg = gt_semantic_seg.cuda()
     else:
-        segmentor = _convert_batchnorm(segmentor)
+        segmentor = revert_sync_batchnorm(segmentor)

     # Test forward train
     losses = segmentor.forward(
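The local helper could be dropped because `revert_sync_batchnorm` preserves the learned affine parameters and running statistics just as `_convert_batchnorm` did; a quick illustrative check (not part of the test suite):

```python
# Illustrative check: conversion keeps the affine weights and running
# statistics of the original SyncBN layer.
import torch
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm

sbn = nn.SyncBatchNorm(4)
sbn.weight.data.fill_(2.0)
sbn.running_mean.fill_(1.0)

bn = revert_sync_batchnorm(sbn)
assert not isinstance(bn, nn.SyncBatchNorm)
assert torch.allclose(bn.weight.data, sbn.weight.data)
assert torch.allclose(bn.running_mean, sbn.running_mean)
```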
10 changes: 10 additions & 0 deletions tools/train.py
@@ -4,9 +4,11 @@
 import os
 import os.path as osp
 import time
+import warnings

 import mmcv
 import torch
+from mmcv.cnn.utils import revert_sync_batchnorm
 from mmcv.runner import get_dist_info, init_dist
 from mmcv.utils import Config, DictAction, get_git_hash

@@ -137,6 +139,14 @@ def main():
         test_cfg=cfg.get('test_cfg'))
     model.init_weights()

+    # SyncBN is not supported for DP
+    if not distributed:
+        warnings.warn(
+            'SyncBN is only supported with DDP. To be compatible with DP, '
+            'we convert SyncBN to BN. Please use dist_train.sh which can '
+            'avoid this error.')
+        model = revert_sync_batchnorm(model)
+
     logger.info(model)

     datasets = [build_dataset(cfg.data.train)]
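The guard matters because a `SyncBatchNorm` forward pass requires CUDA input and an initialized `torch.distributed` process group, so a plain non-distributed run crashes without the conversion. An illustrative repro, not from the repository (the exact exception depends on the PyTorch version):

```python
# SyncBN raises in a non-distributed, CPU-only training run; converting
# to plain BN lets the forward pass succeed.
import torch
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 4, 1), nn.SyncBatchNorm(4))
x = torch.randn(2, 3, 8, 8)

try:
    model(x)  # no GPU input / no process group initialized
except (ValueError, RuntimeError) as e:
    print('SyncBN without DDP failed:', e)

model = revert_sync_batchnorm(model)
out = model(x)  # works after conversion
```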
