'''
Function:
    Implementation of "Mining Contextual Information Beyond Image for Semantic Segmentation"
Author:
    Zhenchao Jin
'''
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from ..deeplabv3 import ASPP
from ..base import BaseSegmentor
from .memory import FeaturesMemory
from ..pspnet import PyramidPoolingModule
from ....utils import SSSegOutputStructure
from ...losses import calculatelosses, calculateloss
from ...backbones import BuildActivation, BuildNormalization
'''MCIBI'''
class MCIBI(BaseSegmentor):
    def __init__(self, cfg, mode):
        super(MCIBI, self).__init__(cfg, mode)
        align_corners, norm_cfg, act_cfg, head_cfg = self.align_corners, self.norm_cfg, self.act_cfg, cfg['head']
        # build norm layer
        if 'norm_cfg' in head_cfg:
            self.norm_layers = nn.ModuleList()
            for in_channels in head_cfg['norm_cfg']['in_channels_list']:
                norm_cfg_copy = head_cfg['norm_cfg'].copy()
                norm_cfg_copy.pop('in_channels_list')
                norm_layer = BuildNormalization(placeholder=in_channels, norm_cfg=norm_cfg_copy)
                self.norm_layers.append(norm_layer)
        # build memory
        if head_cfg['downsample_backbone']['stride'] > 1:
            self.downsample_backbone = nn.Sequential(
                nn.Conv2d(head_cfg['in_channels'], head_cfg['in_channels'], **head_cfg['downsample_backbone']),
                BuildNormalization(placeholder=head_cfg['in_channels'], norm_cfg=norm_cfg),
                BuildActivation(act_cfg),
            )
        context_within_image_cfg = head_cfg['context_within_image']
        if context_within_image_cfg['is_on']:
            cwi_cfg = context_within_image_cfg['cfg']
            cwi_cfg.update({
                'in_channels': head_cfg['in_channels'], 'out_channels': head_cfg['feats_channels'], 'align_corners': align_corners,
                'norm_cfg': copy.deepcopy(norm_cfg), 'act_cfg': copy.deepcopy(act_cfg),
            })
            supported_context_modules = {
                'aspp': ASPP, 'ppm': PyramidPoolingModule,
            }
            self.context_within_image_module = supported_context_modules[context_within_image_cfg['type']](**cwi_cfg)
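            # ASPP/PPM capture context within the current image, complementing the
            # cross-image (dataset-level) context mined by the memory module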
        self.bottleneck = nn.Sequential(
            nn.Conv2d(head_cfg['in_channels'], head_cfg['feats_channels'], kernel_size=3, stride=1, padding=1, bias=False),
            BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
            BuildActivation(act_cfg),
        )
        self.memory_module = FeaturesMemory(
            num_classes=cfg['num_classes'], feats_channels=head_cfg['feats_channels'], transform_channels=head_cfg['transform_channels'], num_feats_per_cls=head_cfg['num_feats_per_cls'],
            out_channels=head_cfg['out_channels'], use_context_within_image=context_within_image_cfg['is_on'], use_hard_aggregate=head_cfg['use_hard_aggregate'],
            norm_cfg=copy.deepcopy(norm_cfg), act_cfg=copy.deepcopy(act_cfg),
        )
        # build decoder
        self.decoder_stage1 = nn.Sequential(
            nn.Conv2d(head_cfg['feats_channels'], head_cfg['feats_channels'], kernel_size=1, stride=1, padding=0, bias=False),
            BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
            BuildActivation(act_cfg),
            nn.Dropout2d(head_cfg['dropout']),
            nn.Conv2d(head_cfg['feats_channels'], cfg['num_classes'], kernel_size=1, stride=1, padding=0),
        )
        self.decoder_stage2 = nn.Sequential(
            nn.Conv2d(head_cfg['out_channels'], head_cfg['out_channels'], kernel_size=1, stride=1, padding=0, bias=False),
            BuildNormalization(placeholder=head_cfg['out_channels'], norm_cfg=norm_cfg),
            BuildActivation(act_cfg),
            nn.Dropout2d(head_cfg['dropout']),
            nn.Conv2d(head_cfg['out_channels'], cfg['num_classes'], kernel_size=1, stride=1, padding=0),
        )
        # build auxiliary decoder
        self.setauxiliarydecoder(cfg['auxiliary'])
        # freeze normalization layer if necessary
        if cfg.get('is_freeze_norm', False): self.freezenormalization()
    '''forward'''
    def forward(self, data_meta, **kwargs):
        img_size = data_meta.images.size(2), data_meta.images.size(3)
        # feed to backbone network
        backbone_outputs = self.transforminputs(self.backbone_net(data_meta.images), selected_indices=self.cfg['backbone'].get('selected_indices'))
        if hasattr(self, 'norm_layers'):
            assert len(backbone_outputs) == len(self.norm_layers)
            for idx in range(len(backbone_outputs)):
                backbone_outputs[idx] = self.norm(backbone_outputs[idx], self.norm_layers[idx])
        if self.cfg['head']['downsample_backbone']['stride'] > 1:
            for idx in range(len(backbone_outputs)):
                backbone_outputs[idx] = self.downsample_backbone(backbone_outputs[idx])
        # feed to context within image module
        feats_ms = self.context_within_image_module(backbone_outputs[-1]) if hasattr(self, 'context_within_image_module') else None
        # feed to memory
        memory_input = self.bottleneck(backbone_outputs[-1])
        preds_stage1 = self.decoder_stage1(memory_input)
        stored_memory, memory_output = self.memory_module(memory_input, preds_stage1, feats_ms)
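        # stored_memory holds the memory bank, shape (num_classes, num_feats_per_cls, feats_channels);
        # memory_output is memory_input fused with class context read from the memory (and with feats_ms if set)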
        # feed to decoder
        preds_stage2 = self.decoder_stage2(memory_output)
        # forward according to the mode
        if self.mode in ['TRAIN', 'TRAIN_DEVELOP']:
            predictions = self.customizepredsandlosses(
                seg_logits=preds_stage2, annotations=data_meta.getannotations(), backbone_outputs=backbone_outputs, losses_cfg=self.cfg['losses'], img_size=img_size, auto_calc_loss=False,
            )
            preds_stage2 = predictions.pop('loss_cls')
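            # customizepredsandlosses registers the upsampled stage-2 logits under 'loss_cls';
            # pop them here and re-register both stages under their own keys below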
            preds_stage1 = F.interpolate(preds_stage1, size=img_size, mode='bilinear', align_corners=self.align_corners)
            predictions.update({'loss_cls_stage1': preds_stage1, 'loss_cls_stage2': preds_stage2})
            with torch.no_grad():
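                # update the memory bank with the current features and ground-truth labels;
                # wrapped in no_grad so the update never contributes gradients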
                self.memory_module.update(
                    features=F.interpolate(memory_input, size=img_size, mode='bilinear', align_corners=self.align_corners),
                    segmentation=data_meta.getannotations()['seg_targets'], learning_rate=kwargs['learning_rate'], **self.cfg['head']['update_cfg']
                )
            loss, losses_log_dict = calculatelosses(
                predictions=predictions, annotations=data_meta.getannotations(), losses_cfg=self.cfg['losses'], pixel_sampler=self.pixel_sampler
            )
            if (kwargs['epoch'] > 1) and self.cfg['head']['use_loss']:
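                # the memory loss is presumably skipped in the first epoch so the memory
                # bank can warm up before its entries are supervised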
                loss_memory, loss_memory_log = self.calculatememoryloss(stored_memory)
                loss += loss_memory
                losses_log_dict['loss_memory'] = loss_memory_log
                loss_total = losses_log_dict.pop('loss_total') + losses_log_dict['loss_memory']
                losses_log_dict['loss_total'] = loss_total
            ssseg_outputs = SSSegOutputStructure(mode=self.mode, loss=loss, losses_log_dict=losses_log_dict) if self.mode == 'TRAIN' else SSSegOutputStructure(mode=self.mode, loss=loss, losses_log_dict=losses_log_dict, seg_logits=preds_stage2)
        else:
            ssseg_outputs = SSSegOutputStructure(mode=self.mode, seg_logits=preds_stage2)
        return ssseg_outputs
    '''norm'''
    def norm(self, x, norm_layer):
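        # flatten the spatial dims so the norm layer sees (N, H*W, C) tokens,
        # normalize, then restore the original (N, C, H, W) layout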
        n, c, h, w = x.shape
        x = x.reshape(n, c, h * w).transpose(2, 1).contiguous()
        x = norm_layer(x)
        x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
        return x
    '''calculatememoryloss'''
    def calculatememoryloss(self, stored_memory):
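        # treat each stored memory vector as a 1x1 feature map and classify it with
        # decoder_stage2: the i-th class's memory vectors should be predicted as class i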
        num_classes, num_feats_per_cls, feats_channels = stored_memory.size()
        stored_memory = stored_memory.reshape(num_classes * num_feats_per_cls, feats_channels, 1, 1)
        preds_memory = self.decoder_stage2(stored_memory)
        # torch.range is deprecated and end-inclusive; torch.arange yields the same 0..num_classes-1 indices
        target = torch.arange(num_classes).type_as(stored_memory).long()
        target = target.unsqueeze(1).repeat(1, num_feats_per_cls).view(-1)
        loss_memory = calculateloss(preds_memory.squeeze(-1).squeeze(-1), target, self.cfg['head']['loss_cfg'])
        loss_memory_log = loss_memory.data.clone()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(loss_memory_log.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
        return loss_memory, loss_memory_log.item()
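# Example head config for MCIBI (a minimal sketch: the keys mirror what this module
# reads from cfg['head'], but the concrete values are illustrative assumptions, not
# the repository's defaults):
#
# head_cfg = {
#     'in_channels': 2048, 'feats_channels': 512, 'transform_channels': 256, 'out_channels': 512,
#     'num_feats_per_cls': 1, 'use_hard_aggregate': False, 'use_loss': True, 'dropout': 0.1,
#     'downsample_backbone': {'kernel_size': 3, 'stride': 1, 'padding': 1},
#     'context_within_image': {'is_on': True, 'type': 'aspp', 'cfg': {}},  # 'cfg' takes module-specific kwargs
#     'update_cfg': {...},  # forwarded verbatim to FeaturesMemory.update
#     'loss_cfg': {...},    # forwarded to calculateloss for the memory loss
# }
# Optionally, 'norm_cfg': {'in_channels_list': [...], ...} normalizes each backbone output.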