Skip to content

Commit

Permalink
This branch is used to match test precision between openmmlab repo an…
Browse files Browse the repository at this point in the history
…d HRNet raw repo;

* Add HRNet weights convert script (tools/scripts/convert_hrnet.py);

* Modify test_pipeline of pascal context dataset;

* Add ignore items;
  • Loading branch information
sennnnn committed Apr 16, 2021
1 parent 9fb99b4 commit 803dca1
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ data
.idea

# custom
*.txt
*.jpg
*.npy
*.pkl
*.pkl.json
Expand Down
3 changes: 2 additions & 1 deletion configs/_base_/datasets/pascal_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
img_scale=None,
img_ratios=[1.0],
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
Expand Down
3 changes: 2 additions & 1 deletion configs/_base_/datasets/pascal_context_59.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
img_scale=None,
img_ratios=[1.0],
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
Expand Down
14 changes: 10 additions & 4 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,16 @@ def __call__(self, results):
result dict.
"""

results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
self.to_rgb)
results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
# results['img'] = mmcv.imnormalize(results['img'], self.mean,
# self.std, self.to_rgb)
# results['img_norm_cfg'] = dict(
# mean=self.mean, std=self.std, to_rgb=self.to_rgb)
image = results['img']
image = image.astype(np.float32)[:, :, ::-1]
image = image / 255.0
image -= [0.485, 0.456, 0.406]
image /= [0.229, 0.224, 0.225]
results['img'] = image
return results

def __repr__(self):
Expand Down
1 change: 1 addition & 0 deletions mmseg/models/decode_heads/fcn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self,
kernel_size=kernel_size,
padding=conv_padding,
dilation=dilation,
bias=True,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
Expand Down
44 changes: 44 additions & 0 deletions tools/scripts/convert_hrnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import argparse
from collections import OrderedDict

import torch


def convert(src, dst):
"""Convert keys in detectron pretrained ResNet models to pytorch style."""
# convert to pytorch style
state_dict = OrderedDict()
src_dict = torch.load(src)
src_state_dict = src_dict.get('state_dict', src_dict)
for k, v in src_state_dict.items():
new_key = k.replace('model', 'backbone')
if new_key.startswith('backbone.last_layer.0'):
state_dict[new_key.replace('backbone.last_layer.0',
'decode_head.convs.0.conv')] = v
elif new_key.startswith('backbone.last_layer.1'):
state_dict[new_key.replace('backbone.last_layer.1',
'decode_head.convs.0.bn')] = v
elif new_key.startswith('backbone.last_layer.3'):
state_dict[new_key.replace('backbone.last_layer.3',
'decode_head.conv_seg')] = v
else:
state_dict[new_key] = v

# save checkpoint
checkpoint = dict()
checkpoint['state_dict'] = state_dict
assert len(state_dict) == len(src_state_dict)
checkpoint['meta'] = dict()
torch.save(checkpoint, dst)


def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument('src', help='src detectron model path')
parser.add_argument('dst', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)


if __name__ == '__main__':
main()
6 changes: 3 additions & 3 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def main():
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
model.CLASSES = checkpoint['meta']['CLASSES']
model.PALETTE = checkpoint['meta']['PALETTE']
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
model.CLASSES = dataset.CLASSES
model.PALETTE = dataset.PALETTE

efficient_test = False
if args.eval_options is not None:
Expand Down

0 comments on commit 803dca1

Please sign in to comment.