Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat]: Add Instance-aware Image Colorization #1173

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8d47540
[feat]: add Instance-aware Image Colorization
ruoningYu Sep 29, 2022
ca983ce
refactor folder
zengyh1900 Oct 10, 2022
d735452
refactor model implementation
zengyh1900 Oct 10, 2022
e319851
refactor demo
zengyh1900 Oct 10, 2022
e22a343
[Enhancement]: add inference module for instance-aware Image Coloriza…
ruoningYu Oct 12, 2022
75bf04f
Merge branch 'dev-1.x' of github.com:open-mmlab/mmediting into add-insta
ruoningYu Oct 12, 2022
c9b1d5b
[Enhancement]: add inference module for Insatance-aware Image Coloriz…
ruoningYu Oct 12, 2022
1aaf2c2
[Fix]: fix inference module of Instance-aware Image Colorization
ruoningYu Oct 15, 2022
5719356
[Refactor]: refactor get_maskrcnn_bbox.py and inst_colorization.py
ruoningYu Oct 16, 2022
1d02652
[Enhancement]: add unit test of Instance-aware Image Colorization
ruoningYu Oct 16, 2022
7a475fe
update configs
zengyh1900 Oct 19, 2022
524f6f5
refactor networks
zengyh1900 Oct 19, 2022
60f19d6
fix siggraphgenerator, i.e., colorization_net
zengyh1900 Oct 20, 2022
1fc4b74
rfactor networks
zengyh1900 Oct 20, 2022
9825d12
fix loading weights
zengyh1900 Oct 25, 2022
ffba00a
refactoring model forward
zengyh1900 Oct 25, 2022
a2fc3ce
refactor packedit
zengyh1900 Oct 26, 2022
ed48fd8
fix color rendering
zengyh1900 Oct 27, 2022
e61cc99
fix inference
zengyh1900 Oct 27, 2022
20d14cb
remove undesired files
zengyh1900 Oct 27, 2022
ceafa26
Merge branch 'dev-1.x' into fix-add-insta
zengyh1900 Oct 27, 2022
d6b0c40
clear code
zengyh1900 Oct 27, 2022
67631ee
[Doc]: update docstring if instance-aware image colorization
ruoningYu Oct 29, 2022
0201cb6
[Enhancement]: add unit test of instance_aware_colorization
ruoningYu Oct 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions configs/inst_colorization/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Instance-aware Image Colorization (CVPR'2020)

> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html)

> **Task**: Colorization

<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization.

## Results and models

## Quick Start

**Train**

<details>
<summary>Train Instructions</summary>

You can use the following commands to train a model with cpu or single/multiple GPUs.

```shell
# CPU train
CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py

# single-gpu train
python tools/train.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py

# multi-gpu train
./tools/dist_train.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py 8
```

For more details, you can refer to **Train a model** part in [train_test.md](/docs/en/user_guides/train_test.md#Train-a-model-in-MMEditing).

</details>

**Test**

<details>
<summary>Test Instructions</summary>

You can use the following commands to test a model with cpu or single/multiple GPUs.

```shell
# CPU test
CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization//inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth

# single-gpu demo
python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth

# multi-gpu test
./tools/dist_test.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8
```

For more details, you can refer to **Test a pre-trained model** part in [train_test.md](/docs/en/user_guides/train_test.md#Test-a-pre-trained-model-in-MMEditing).

</details>

<details>
<summary align="right">Instance-aware Image Colorization (CVPR'2020)</summary>

```bibtex
@inproceedings{Su-CVPR-2020,
author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin},
title = {Instance-aware Image Colorization},
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2020}
}
```

</details>
75 changes: 75 additions & 0 deletions configs/inst_colorization/README_zh-CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Instance-aware Image Colorization (CVPR'2020)

> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html)

> **任务**: 图像上色

<!-- [ALGORITHM] -->

## 摘要

<!-- [ABSTRACT] -->

Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization.

## 结果和模型

## 快速开始

**训练**

<details>
<summary>训练说明</summary>

您可以使用以下命令来训练模型。

```shell
# CPU上训练
CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py

# 单个GPU上训练
python tools/train.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py

# 多个GPU上训练
./tools/dist_train.sh configs/insta/inst-colorizatioon_cocostuff_full_256x256.py 8
```

更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Train a model** 部分。

</details>

**测试**

<details>
<summary>测试说明</summary>

您可以使用以下命令来测试模型。

```shell
# CPU上测试
CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth

# 单个GPU上 demo
python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py work_dirs/checkpoints/instance_aware_cocostuff.pth work_dirs/colorization_example.jpg work_dirs/output_example.png

# 多个GPU上测试
./tools/dist_test.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8
```

更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Test a pre-trained model** 部分。

</details>

<details>
<summary align="right">Instance-aware Image Colorization (CVPR'2020)</summary>

```bibtex
@inproceedings{Su-CVPR-2020,
author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin},
title = {Instance-aware Image Colorization},
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2020}
}
```

</details>
58 changes: 58 additions & 0 deletions configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
_base_ = ['../_base_/default_runtime.py']

exp_name = 'inst-colorization_cocostuff_256x256'
save_dir = './'
work_dir = '..'

stage = 'full'

model = dict(
type='InstColorization',
data_preprocessor=dict(
type='EditDataPreprocessor',
mean=[127.5],
std=[127.5],
),
image_model=dict(
type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'),
instance_model=dict(
type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'),
fusion_model=dict(
type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'),
color_data_opt=dict(
ab_thresh=0,
p=1.0,
sample_PS=[
1,
2,
3,
4,
5,
6,
7,
8,
9,
],
ab_norm=110,
ab_max=110.,
ab_quant=10.,
l_norm=100.,
l_cent=50.,
mask_cent=0.5),
which_direction='AtoB',
loss=dict(type='HuberLoss', delta=.01))

# yapf: disable
test_pipeline = [
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(
type='InstanceCrop',
config_file='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', # noqa
finesize=256),
dict(
type='Resize',
keys=['img', 'cropped_img'],
scale=(256, 256),
keep_ratio=False),
dict(type='PackEditInputs'),
]
9 changes: 9 additions & 0 deletions configs/inst_colorization/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Collections:
- Metadata:
Architecture:
- Instance-aware Image Colorization
Name: Instance-aware Image Colorization
Paper:
- https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html
README: configs/inst_colorization/README.md
Models: []
43 changes: 43 additions & 0 deletions demo/colorization_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import mmcv
import torch

from mmedit.apis import colorization_inference, init_model
from mmedit.utils import modify_args, tensor2img


def parse_args():
modify_args()
parser = argparse.ArgumentParser(description='Colorzation demo')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoints', help='checkpoints file path')
parser.add_argument('img_path', help='path to input image file')
parser.add_argument('save_path', help='path to save generation result')
parser.add_argument(
'--imshow', action='store_true', help='whether show image with opencv')
parser.add_argument('--device', type=int, default=0, help='CUDA device id')
args = parser.parse_args()
return args


def main():
args = parse_args()

if args.device < 0 or not torch.cuda.is_available():
device = torch.device('cpu')
else:
device = torch.device('cuda', args.device)

model = init_model(args.config, args.checkpoints, device=device)
output = colorization_inference(model, args.img_path)
result = tensor2img(output)
mmcv.imwrite(result, args.save_path)

if args.imshow:
mmcv.imshow(output, 'predicted generation result')


if __name__ == '__main__':
main()
19 changes: 7 additions & 12 deletions mmedit/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .colorization_inference import colorization_inference
from .gan_inference import sample_conditional_model, sample_unconditional_model
from .inference import delete_cfg, init_model, set_random_seed
from .inpainting_inference import inpainting_inference
Expand All @@ -10,16 +11,10 @@
from .video_interpolation_inference import video_interpolation_inference

__all__ = [
'init_model',
'delete_cfg',
'set_random_seed',
'matting_inference',
'inpainting_inference',
'restoration_inference',
'restoration_video_inference',
'restoration_face_inference',
'video_interpolation_inference',
'sample_conditional_model',
'sample_unconditional_model',
'sample_img2img_model',
'init_model', 'delete_cfg', 'set_random_seed', 'matting_inference',
'inpainting_inference', 'restoration_inference',
'restoration_video_inference', 'restoration_face_inference',
'video_interpolation_inference', 'sample_conditional_model',
'sample_unconditional_model', 'sample_img2img_model',
'colorization_inference'
]
50 changes: 50 additions & 0 deletions mmedit/apis/colorization_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.dataset import Compose
from mmengine.dataset.utils import default_collate as collate
from torch.nn.parallel import scatter


def colorization_inference(model, img):
"""Inference image with the model.

Args:
model (nn.Module): The loaded model.
img (str): Image file path.

Returns:
Tensor: The predicted colorization result.
"""
device = next(model.parameters()).device

# build the data pipeline
test_pipeline = Compose(model.cfg.test_pipeline)
# prepare data
data = dict(img_path=img)
_data = test_pipeline(data)
data = dict()
data['inputs'] = _data['inputs'] / 255.0
data = collate([data])
data['data_samples'] = [_data['data_samples']]
if 'cuda' in str(device):
data = scatter(data, [device])[0]
data['data_samples'][0].cropped_img.data = scatter(
data['data_samples'][0].cropped_img.data, [device])[0] / 255.0

data['data_samples'][0].box_info.data = scatter(
data['data_samples'][0].box_info.data, [device])[0]

data['data_samples'][0].box_info_2x.data = scatter(
data['data_samples'][0].box_info_2x.data, [device])[0]

data['data_samples'][0].box_info_4x.data = scatter(
data['data_samples'][0].box_info_4x.data, [device])[0]

data['data_samples'][0].box_info_8x.data = scatter(
data['data_samples'][0].box_info_8x.data, [device])[0]

# forward the model
with torch.no_grad():
result = model(mode='tensor', **data)

return result
12 changes: 3 additions & 9 deletions mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
from .unpaired_image_dataset import UnpairedImageDataset

__all__ = [
'AdobeComp1kDataset',
'BasicImageDataset',
'BasicFramesDataset',
'BasicConditionalDataset',
'UnpairedImageDataset',
'PairedImageDataset',
'ImageNet',
'CIFAR10',
'GrowScaleImgDataset',
'AdobeComp1kDataset', 'BasicImageDataset', 'BasicFramesDataset',
'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
'ImageNet', 'CIFAR10', 'GrowScaleImgDataset'
]
3 changes: 2 additions & 1 deletion mmedit/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
GenerateFrameIndiceswithPadding,
GenerateSegmentIndices)
from .get_masked_image import GetMaskedImage
from .get_maskrcnn_bbox import InstanceCrop
from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask,
LoadPairedImageFromFile)
from .matlab_like_resize import MATLABLikeResize
Expand Down Expand Up @@ -45,5 +46,5 @@
'GenerateSoftSeg', 'FormatTrimap', 'TransformTrimap', 'GenerateTrimap',
'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg',
'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile',
'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad'
'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop'
]
Loading