Skip to content

Commit

Permalink
Modify init_cfg arg
Browse files Browse the repository at this point in the history
  • Loading branch information
sennnnn committed Jun 24, 2021
1 parent ebcb241 commit b24a981
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
14 changes: 7 additions & 7 deletions mmseg/models/decode_heads/setr_up_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,32 @@ class SETRUPHead(BaseDecodeHead):
up_scale (int): The scale factor of interpolate. Default:4.
kernel_size (int): The kernel size of convolution when decoding
feature information from backbone. Default: 3.
init_cfg (dict | list[dict] | None): Initialization config dict.
Default: dict(
type='Constant', val=1.0, bias=0, layer='LayerNorm').
"""

def __init__(self,
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
num_convs=1,
up_scale=4,
kernel_size=3,
init_cfg=None,
init_cfg=dict(
type='Constant', val=1.0, bias=0, layer='LayerNorm'),
**kwargs):

assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'

super(SETRUPHead, self).__init__(**kwargs)
super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs)

assert isinstance(self.in_channels, int)

self.init_cfg = [
dict(type='Constant', val=1.0, bias=0, layer='LayerNorm')
]

_, self.norm = build_norm_layer(norm_layer, self.in_channels)

self.up_convs = nn.ModuleList()
in_channels = self.in_channels
out_channels = self.channels
for i in range(num_convs):
for _ in range(num_convs):
self.up_convs.append(
nn.Sequential(
ConvModule(
Expand Down
7 changes: 4 additions & 3 deletions tests/test_models/test_heads/test_setr_up_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ def test_setr_up_head(capsys):
# as embed_dim.
SETRUPHead(in_channels=(32, 32), channels=16, num_classes=19)

# test init_weights of head
# test init_cfg of head
head = SETRUPHead(
in_channels=32,
channels=16,
norm_cfg=dict(type='SyncBN'),
num_classes=19)
head.init_weights()
num_classes=19,
init_cfg=dict(type='Kaiming'))
super(SETRUPHead, head).init_weights()

# test inference of Naive head
# the auxiliary head of Naive head is same as Naive head
Expand Down

0 comments on commit b24a981

Please sign in to comment.