Skip to content

Commit

Permalink
[Feature] Support Segmenter (open-mmlab#955)
Browse files Browse the repository at this point in the history
* segmenter: add model

* update

* readme: update

* config: update

* segmenter: update readme

* segmenter: update

* segmenter: update

* segmenter: update

* configs: set checkpoint path to pretrain folder

* segmenter: modify vit-s/lin, remove data config

* rreadme: update

* configs: transfer from _base_ to segmenter

* configs: add 8x1 suffix

* configs: remove redundant lines

* configs: cleanup

* first attempt

* swipe CI error

* Update mmseg/models/decode_heads/__init__.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>

* segmenter_linear: use fcn backbone

* segmenter_mask: update

* models: add segmenter vit

* decoders: yapf+remove unused imports

* apply precommit

* segmenter/linear_head: fix

* segmenter/linear_header: fix

* segmenter: fix mask transformer

* fix error

* segmenter/mask_head: use trunc_normal init

* refactor segmenter head

* Fetch upstream (#1)

* [Feature] Change options to cfg-option (open-mmlab#1129)

* [Feature] Change option to cfg-option

* add expire date and fix the docs

* modify docstring

* [Fix] Add <!-- [ABSTRACT] --> in metafile open-mmlab#1127

* [Fix] Fix correct num_classes of HRNet in LoveDA dataset open-mmlab#1136

* Bump to v0.20.1 (open-mmlab#1138)

* bump version 0.20.1

* bump version 0.20.1

* [Fix] revise --option to --options open-mmlab#1140

Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>

* decode_head: switch from linear to fcn

* fix init list formatting

* configs: remove variants, keep only vit-s on ade

* align inference metric of vit-s-mask

* configs: add vit t/b/l

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update mmseg/models/decode_heads/segmenter_mask_head.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* model_converters: use torch instead of einops

* setup: remove einops

* segmenter_mask: fix missing imports

* add necessary imported init funtion

* segmenter/seg-l: set resolution to 640

* segmenter/seg-l: fix test size

* fix vitjax2mmseg

* add README and unittest

* fix unittest

* add docstring

* refactor config and add pretrained link

* fix typo

* add paper name in readme

* change segmenter config names

* fix typo in readme

* fix typos in readme

* fix segmenter typo

* fix segmenter typo

* delete redundant comma in config files

* delete redundant comma in config files

* fix convert script

* update lateset master version

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
Co-authored-by: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
  • Loading branch information
5 people authored Jan 26, 2022
1 parent 80e8504 commit ee47c41
Show file tree
Hide file tree
Showing 16 changed files with 754 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ Supported methods:
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)

Supported datasets:
Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)

已支持的数据集:
Expand Down
35 changes: 35 additions & 0 deletions configs/_base_/models/segmenter_vit-b16_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='pretrain/vit_base_p16_384.pth',
backbone=dict(
type='VisionTransformer',
img_size=(512, 512),
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
drop_path_rate=0.1,
attn_drop_rate=0.0,
drop_rate=0.0,
final_norm=True,
norm_cfg=backbone_norm_cfg,
with_cls_token=True,
interpolate_mode='bicubic',
),
decode_head=dict(
type='SegmenterMaskTransformerHead',
in_channels=768,
channels=768,
num_classes=150,
num_layers=2,
num_heads=12,
embed_dims=768,
dropout_ratio=0.0,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
),
test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(480, 480)),
)
73 changes: 73 additions & 0 deletions configs/segmenter/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Segmenter

[Segmenter: Transformer for Semantic Segmentation](https://arxiv.org/abs/2105.05633)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/rstrudel/segmenter">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.21.0/mmseg/models/decode_heads/segmenter_mask_head.py#L15">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

Image segmentation is often ambiguous at the level of individual image patches and requires contextual information to reach label consensus. In this paper we introduce Segmenter, a transformer model for semantic segmentation. In contrast to convolution-based methods, our approach allows to model global context already at the first layer and throughout the network. We build on the recent Vision Transformer (ViT) and extend it to semantic segmentation. To do so, we rely on the output embeddings corresponding to image patches and obtain class labels from these embeddings with a point-wise linear decoder or a mask transformer decoder. We leverage models pre-trained for image classification and show that we can fine-tune them on moderate sized datasets available for semantic segmentation. The linear decoder allows to obtain excellent results already, but the performance can be further improved by a mask transformer generating class masks. We conduct an extensive ablation study to show the impact of the different parameters, in particular the performance is better for large models and small patch sizes. Segmenter attains excellent results for semantic segmentation. It outperforms the state of the art on both ADE20K and Pascal Context datasets and is competitive on Cityscapes.

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/148507554-87eb80bd-02c7-4c31-b102-c6141e231ec8.png" width="70%"/>
</div>

```bibtex
@article{strudel2021Segmenter,
title={Segmenter: Transformer for Semantic Segmentation},
author={Strudel, Robin and Ricardo, Garcia, and Laptev, Ivan and Schmid, Cordelia},
journal={arXiv preprint arXiv:2105.05633},
year={2021}
}
```


## Usage

To use the pre-trained ViT model from [Segmenter](https://github.com/rstrudel/segmenter), it is necessary to convert keys.

We provide a script [`vitjax2mmseg.py`](../../tools/model_converters/vitjax2mmseg.py) in the tools directory to convert the key of models from [ViT-AugReg](https://github.com/rwightman/pytorch-image-models/blob/f55c22bebf9d8afc449d317a723231ef72e0d662/timm/models/vision_transformer.py#L54-L106) to MMSegmentation style.

```shell
python tools/model_converters/vitjax2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```

E.g.

```shell
python tools/model_converters/vitjax2mmseg.py \
Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz \
pretrain/vit_tiny_p16_384.pth
```

This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.

In our default setting, pretrained models and their corresponding [ViT-AugReg](https://github.com/rwightman/pytorch-image-models/blob/f55c22bebf9d8afc449d317a723231ef72e0d662/timm/models/vision_transformer.py#L54-L106) models could be defined below:

| pretrained models | original models |
| ------ | -------- |
|vit_tiny_p16_384.pth | ['vit_tiny_patch16_384'](https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz) |
|vit_small_p16_384.pth | ['vit_small_patch16_384'](https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz) |
|vit_base_p16_384.pth | ['vit_base_patch16_384'](https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz) |
|vit_large_p16_384.pth | ['vit_large_patch16_384'](https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz) |

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- |
| Segmenter-Mask | ViT-T_16 | 512x512 | 160000 | 1.21 | 27.98 | 39.99 | 40.83 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706-ffcf7509.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
| Segmenter-Linear | ViT-S_16 | 512x512 | 160000 | 1.78 | 28.07 | 45.75 | 46.82 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713-39658c46.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713.log.json) |
| Segmenter-Mask | ViT-S_16 | 512x512 | 160000 | 2.03 | 24.80 | 46.19 | 47.85 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706-511bb103.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
| Segmenter-Mask | ViT-B_16 |512x512 | 160000 | 4.20 | 13.20 | 49.60 | 51.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706-bc533b08.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706.log.json) |
| Segmenter-Mask | ViT-L_16 |640x640 | 160000 | 16.56 | 2.62 | 52.16 | 53.65 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750-7ef345be.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750.log.json) |
125 changes: 125 additions & 0 deletions configs/segmenter/segmenter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
Collections:
- Name: segmenter
Metadata:
Training Data:
- ADE20K
Paper:
URL: https://arxiv.org/abs/2105.05633
Title: 'Segmenter: Transformer for Semantic Segmentation'
README: configs/segmenter/README.md
Code:
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.21.0/mmseg/models/decode_heads/segmenter_mask_head.py#L15
Version: v0.21.0
Converted From:
Code: https://github.com/rstrudel/segmenter
Models:
- Name: segmenter_vit-t_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-T_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 35.74
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 1.21
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 39.99
mIoU(ms+flip): 40.83
Config: configs/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-t_mask_8x1_512x512_160k_ade20k/segmenter_vit-t_mask_8x1_512x512_160k_ade20k_20220105_151706-ffcf7509.pth
- Name: segmenter_vit-s_linear_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-S_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 35.63
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 1.78
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 45.75
mIoU(ms+flip): 46.82
Config: configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k/segmenter_vit-s_linear_8x1_512x512_160k_ade20k_20220105_151713-39658c46.pth
- Name: segmenter_vit-s_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-S_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 40.32
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 2.03
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 46.19
mIoU(ms+flip): 47.85
Config: configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k/segmenter_vit-s_mask_8x1_512x512_160k_ade20k_20220105_151706-511bb103.pth
- Name: segmenter_vit-b_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-B_16
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 75.76
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 4.2
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 49.6
mIoU(ms+flip): 51.07
Config: configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k/segmenter_vit-b_mask_8x1_512x512_160k_ade20k_20220105_151706-bc533b08.pth
- Name: segmenter_vit-l_mask_8x1_512x512_160k_ade20k
In Collection: segmenter
Metadata:
backbone: ViT-L_16
crop size: (640,640)
lr schd: 160000
inference time (ms/im):
- value: 381.68
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (640,640)
Training Memory (GB): 16.56
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 52.16
mIoU(ms+flip): 53.65
Config: configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k/segmenter_vit-l_mask_8x1_512x512_160k_ade20k_20220105_162750-7ef345be.pth
43 changes: 43 additions & 0 deletions configs/segmenter/segmenter_vit-b_mask_8x1_512x512_160k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
_base_ = [
'../_base_/models/segmenter_vit-b16_mask.py',
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
optimizer = dict(lr=0.001, weight_decay=0.0)

img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
# num_gpus: 8 -> batch_size: 8
samples_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
60 changes: 60 additions & 0 deletions configs/segmenter/segmenter_vit-l_mask_8x1_512x512_160k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
_base_ = [
'../_base_/models/segmenter_vit-b16_mask.py',
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]

model = dict(
pretrained='pretrain/vit_large_p16_384.pth',
backbone=dict(
type='VisionTransformer',
img_size=(640, 640),
embed_dims=1024,
num_layers=24,
num_heads=16),
decode_head=dict(
type='SegmenterMaskTransformerHead',
in_channels=1024,
channels=1024,
num_heads=16,
embed_dims=1024),
test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(608, 608)))

optimizer = dict(lr=0.001, weight_decay=0.0)

img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 640),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
# num_gpus: 8 -> batch_size: 8
samples_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
_base_ = './segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py'

model = dict(
decode_head=dict(
_delete_=True,
type='FCNHead',
in_channels=384,
channels=384,
num_convs=0,
dropout_ratio=0.0,
concat_input=False,
num_classes=150,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
Loading

0 comments on commit ee47c41

Please sign in to comment.