[Feature] Add configs for AOT-GAN #681

Merged · 20 commits · May 30, 2022
188 changes: 188 additions & 0 deletions configs/inpainting/AOT-GAN/AOT-GAN_512x512_places.py
@@ -0,0 +1,188 @@
model = dict(
type='AOTInpaintor',
encdec=dict(
type='AOTEncoderDecoder',
encoder=dict(type='AOTEncoder'),
decoder=dict(type='AOTDecoder'),
        # The AOT neck stacks 8 AOT blocks, each aggregating dilated
        # convolutions with rates 1, 2, 4 and 8.
        dilation_neck=dict(
            type='AOTBlockNeck', dilation_rates='1+2+4+8', num_aotblock=8)),
disc=dict(
type='SoftMaskPatchDiscriminator',
in_channels=3,
base_channels=64,
num_conv=3,
with_spectral_norm=True,
),
    # 'smgan' is the soft mask-guided GAN loss used by AOT-GAN.
    loss_gan=dict(
type='GANLoss',
gan_type='smgan',
loss_weight=0.01,
),
loss_composed_percep=dict(
type='PerceptualLoss',
vgg_type='vgg19',
layer_weights_perceptual={
'1': 1.,
'6': 1.,
'11': 1.,
'20': 1.,
'29': 1.,
},
layer_weights_style={
'8': 1.,
'17': 1.,
'26': 1.,
'31': 1.,
},
perceptual_weight=0.1,
style_weight=250),
loss_out_percep=True,
loss_l1_valid=dict(
type='L1Loss',
loss_weight=1.,
),
pretrained=None)

# Update the discriminator once per generator step.
train_cfg = dict(disc_step=1)
test_cfg = dict(metrics=['l1', 'psnr', 'ssim'])

dataset_type = 'ImgInpaintingDataset'
input_shape = (512, 512)

mask_root = 'data/masks'

train_pipeline = [
dict(type='LoadImageFromFile', key='gt_img', channel_order='rgb'),
dict(
type='LoadMask',
mask_mode='set',
mask_config=dict(
mask_list_file=f'{mask_root}/train_places_mask_list.txt',
prefix=mask_root,
io_backend='disk',
flag='unchanged',
file_client_kwargs=dict())),
dict(
type='RandomResizedCrop',
keys=['gt_img'],
crop_size=input_shape,
),
dict(type='Flip', keys=['gt_img', 'mask'], direction='horizontal'),
dict(
type='Resize',
keys=['mask'],
scale=input_shape,
keep_ratio=False,
interpolation='nearest'),
dict(type='RandomRotation', keys=['mask'], degrees=(0.0, 45.0)),
dict(
type='ColorJitter',
keys=['gt_img'],
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.5),
    # mean/std of 127.5 rescale 8-bit pixel values to roughly [-1, 1].
    dict(
        type='Normalize',
keys=['gt_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=False),
dict(type='GetMaskedImage'),
dict(
type='Collect',
keys=['gt_img', 'masked_img', 'mask'],
meta_keys=['gt_img_path']),
dict(type='ImageToTensor', keys=['gt_img', 'masked_img', 'mask'])
]

test_pipeline = [
dict(type='LoadImageFromFile', key='gt_img', channel_order='rgb'),
dict(
type='LoadMask',
mask_mode='set',
mask_config=dict(
mask_list_file=f'{mask_root}/val_places_mask_list.txt',
prefix=mask_root,
io_backend='disk',
flag='unchanged',
file_client_kwargs=dict())),
dict(
type='Crop',
keys=['gt_img'],
crop_size=(512, 512),
random_crop=False,
),
dict(
type='Normalize',
keys=['gt_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=True),
dict(type='GetMaskedImage'),
dict(
type='Collect',
keys=['gt_img', 'masked_img', 'mask'],
meta_keys=['gt_img_path']),
dict(type='ImageToTensor', keys=['gt_img', 'masked_img', 'mask'])
]

data_root = 'data/places365'

data = dict(
workers_per_gpu=4,
train_dataloader=dict(samples_per_gpu=12, drop_last=True),
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=f'{data_root}/train_places_img_list.txt',
data_prefix=data_root,
pipeline=train_pipeline,
test_mode=False),
val=dict(
type=dataset_type,
ann_file=f'{data_root}/val_places_img_list.txt',
data_prefix=data_root,
pipeline=test_pipeline,
test_mode=True),
test=dict(
type=dataset_type,
        ann_file=f'{data_root}/val_places_img_list.txt',
data_prefix=data_root,
pipeline=test_pipeline,
test_mode=True))

optimizers = dict(
generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9)),
disc=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9)))

lr_config = dict(policy='Fixed', by_epoch=False)

checkpoint_config = dict(by_epoch=False, interval=10000)
log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
dict(type='TensorboardLoggerHook'),
dict(type='PaviLoggerHook', init_kwargs=dict(project='mmedit'))
])

visual_config = dict(
type='VisualizationHook',
output_dir='visual',
interval=1000,
res_name_list=['gt_img', 'masked_img', 'fake_res', 'fake_img'],
)

evaluation = dict(interval=50000)

total_iters = 500002
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './workdirs/aotgan_places'
load_from = None
resume_from = None
workflow = [('train', 10000)]
exp_name = 'AOT-GAN_512x512_places'
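
For anyone trying this config out, here is a minimal sketch of loading and tweaking it with mmcv's Config; the path follows the diff header above, and the override at the end is purely illustrative, not part of this PR:

from mmcv import Config

# Load the config the same way the training tools would.
cfg = Config.fromfile('configs/inpainting/AOT-GAN/AOT-GAN_512x512_places.py')

# Inspect a few key settings.
print(cfg.model.encdec.dilation_neck)             # AOTBlockNeck, rates 1+2+4+8
print(cfg.data.train_dataloader.samples_per_gpu)  # 12

# Hypothetical override, e.g. to fit smaller GPUs.
cfg.data.train_dataloader.samples_per_gpu = 6

Training would normally go through the repo's standard entry points (tools/train.py or tools/dist_train.sh); with samples_per_gpu=12, a 4-GPU run corresponds to the 4x12 batch notation discussed in the review below.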
Member

We have a naming convention for the models. Would you mind using the convention below?

{model}_[model setting]_{backbone}_[refiner]_[norm setting]_[misc]_[gpu x batch_per_gpu]_{schedule}_{dataset}

An example can also be found here
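
For this config only some of those fields apply; filling in the resolution, the gpu x batch_per_gpu pair and the dataset would give a name along the lines of AOT-GAN_512x512_4x12_places.py (4 GPUs at samples_per_gpu=12), which is what the author proposes further down.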

Member

Hi @quincylin1, would you mind also providing the weights converted from the official weights?

The checkpoint can be named as follows:

{config_name}_{date}-{hash}.pth

For the hash, you can use tools/publish_model.py to generate it.
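
Assuming the usual OpenMMLab interface for that script (input checkpoint and output path as positional arguments, with a short content hash appended before the .pth suffix), the call would look roughly like

python tools/publish_model.py workdirs/aotgan_places/iter_500000.pth AOT-GAN_512x512_places_{date}.pth

which yields a file in the {config_name}_{date}-{hash}.pth form. The input path is only an example derived from the work_dir and checkpoint interval in the config above.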

Contributor Author

> We have a naming convention for the models. Would you mind using the convention below?
>
> {model}_[model setting]_{backbone}_[refiner]_[norm setting]_[misc]_[gpu x batch_per_gpu]_{schedule}_{dataset}
>
> An example can also be found here

@ckkelvinchan just wondering if this name is ok for this config: AOT-GAN_512x512_4x12_places.py

Member

Yeah I think it works.

find_unused_parameters = False