Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Convert SyncBN to BN when training on DP #772

Merged
merged 14 commits into from
Sep 15, 2021
2 changes: 1 addition & 1 deletion .dev/gather_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_final_results(log_json_path, iter_num):
def parse_args():
parser = argparse.ArgumentParser(description='Gather benchmarked models')
parser.add_argument(
'-c', '--config-name', type=str, help='Process the selected config.')
'-f', '--config-name', type=str, help='Process the selected config.')
parser.add_argument(
'-w',
'--work-dir',
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ARG PYTORCH="1.6.0"
ARG CUDA="10.1"
ARG CUDNN="7"
ARG MMCV="1.3.12"
ARG MMCV="1.3.13"

FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

Expand Down
2 changes: 1 addition & 1 deletion docker/serve/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ARG CUDA="10.1"
ARG CUDNN="7"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

ARG MMCV="1.3.12"
ARG MMCV="1.3.13"
clownrat6 marked this conversation as resolved.
Show resolved Hide resolved
ARG MMSEG="0.17.0"

ENV PYTHONUNBUFFERED TRUE
Expand Down
2 changes: 1 addition & 1 deletion docs/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the

| MMSegmentation version | MMCV version |
|:-------------------:|:-------------------:|
| master | mmcv-full>=1.3.7, <1.4.0 |
| master | mmcv-full>=1.3.13, <1.4.0 |
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
Expand Down
8 changes: 8 additions & 0 deletions docs/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c

### Train with a single GPU

official support:

```shell
./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
```

experimental support (Convert SyncBN to BN):

```shell
python tools/train.py ${CONFIG_FILE} [optional arguments]
```
Expand Down
2 changes: 1 addition & 1 deletion docs_zh-CN/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

| MMSegmentation 版本 | MMCV 版本 |
|:-------------------:|:-------------------:|
| master | mmcv-full>=1.3.7, <1.4.0 |
| master | mmcv-full>=1.3.13, <1.4.0 |
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |
Expand Down
2 changes: 1 addition & 1 deletion mmseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .version import __version__, version_info

MMCV_MIN = '1.3.7'
MMCV_MIN = '1.3.13'
MMCV_MAX = '1.4.0'


Expand Down
26 changes: 2 additions & 24 deletions tests/test_models/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch
import torch.nn as nn
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
from mmcv.cnn.utils import revert_sync_batchnorm


def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
Expand Down Expand Up @@ -189,28 +189,6 @@ def _check_input_dim(self, inputs):
pass


def _convert_batchnorm(module):
module_output = module
if isinstance(module, SyncBatchNorm):
# to be consistent with SyncBN, we hack dim check function in BN
module_output = _BatchNorm(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output


@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
_check_input_dim)
@patch('torch.distributed.get_world_size', get_world_size)
Expand Down Expand Up @@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file):
imgs = imgs.cuda()
gt_semantic_seg = gt_semantic_seg.cuda()
else:
segmentor = _convert_batchnorm(segmentor)
segmentor = revert_sync_batchnorm(segmentor)

# Test forward train
losses = segmentor.forward(
Expand Down
10 changes: 10 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash

Expand Down Expand Up @@ -137,6 +139,14 @@ def main():
test_cfg=cfg.get('test_cfg'))
model.init_weights()

# SyncBN is not support for DP
if not distributed:
warnings.warn(
'SyncBN is only supported with DDP. To be compatible with DP, '
'we convert SyncBN to BN. Please use dist_train.sh which can '
'avoid this error.')
model = revert_sync_batchnorm(model)

logger.info(model)

datasets = [build_dataset(cfg.data.train)]
Expand Down