From d5586d375ab52b052e1e7ea8008a281fa587d3f2 Mon Sep 17 00:00:00 2001 From: rangoliu Date: Fri, 10 Nov 2023 13:03:37 +0800 Subject: [PATCH 1/2] [Fix] fix best practice (#2063) * fix best practice * fix --- README.md | 6 ++++-- README_zh-CN.md | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index db5e1adea..757f10ea0 100644 --- a/README.md +++ b/README.md @@ -140,8 +140,6 @@ Currently, MMagic support multiple image and video generation/editing tasks. https://user-images.githubusercontent.com/49083766/233564593-7d3d48ed-e843-4432-b610-35e3d257765c.mp4 -The best practice on our main branch works with **Python 3.8+** and **PyTorch 1.10+**. - ### ✨ Major features - **State of the Art Models** @@ -156,6 +154,10 @@ The best practice on our main branch works with **Python 3.8+** and **PyTorch 1. By using MMEngine and MMCV of OpenMMLab 2.0 framework, MMagic decompose the editing framework into different modules and one can easily construct a customized editor framework by combining different modules. We can define the training process just like playing with Legos and provide rich components and strategies. In MMagic, you can complete controls on the training process with different levels of APIs. With the support of [MMSeparateDistributedDataParallel](https://github.com/open-mmlab/mmengine/blob/main/mmengine/model/wrappers/seperate_distributed.py), distributed training for dynamic architectures can be easily implemented. +### ✨ Best Practice + +- The best practice on our main branch works with **Python 3.9+** and **PyTorch 2.0+**. +

🔝Back to Table of Contents

## 🙌 Contributing diff --git a/README_zh-CN.md b/README_zh-CN.md index f5c947f21..17f0c7188 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -138,8 +138,6 @@ MMagic 是基于 PyTorch 的图像&视频编辑和生成开源工具箱。是 [O https://user-images.githubusercontent.com/49083766/233564593-7d3d48ed-e843-4432-b610-35e3d257765c.mp4 -主分支代码的最佳实践基于 **Python 3.8+** 和 **PyTorch 1.10+** 。 - ### ✨ 主要特性 - **SOTA 算法** @@ -154,6 +152,10 @@ https://user-images.githubusercontent.com/49083766/233564593-7d3d48ed-e843-4432- 通过 OpenMMLab 2.0 框架的 MMEngine 和 MMCV, MMagic 将编辑框架分解为不同的组件,并且可以通过组合不同的模块轻松地构建自定义的编辑器模型。我们可以像搭建“乐高”一样定义训练流程,提供丰富的组件和策略。在 MMagic 中,你可以使用不同的 APIs 完全控制训练流程。得益于 [MMSeparateDistributedDataParallel](https://github.com/open-mmlab/mmengine/blob/main/mmengine/model/wrappers/seperate_distributed.py), 动态模型结构的分布式训练可以轻松实现。 +### ✨ 最佳实践 + +- 主分支代码的最佳实践基于 **Python 3.9+** 和 **PyTorch 2.0+** 。 +

🔝返回目录

## 🙌 参与贡献 From c2f8f3ad3907666bf0673eb778121dae4ba01924 Mon Sep 17 00:00:00 2001 From: rangoliu Date: Fri, 10 Nov 2023 13:16:44 +0800 Subject: [PATCH 2/2] [Enhance] add new config for _base_ dir (#2053) * add new config * fix lint * add gitkeep --- .../_base_/datasets/basicvsr_test_config.py | 192 ++++++++++++++++++ mmagic/configs/_base_/datasets/celeba.py | 48 +++++ .../configs/_base_/datasets/cifar10_nopad.py | 45 ++++ mmagic/configs/_base_/datasets/comp1k.py | 45 ++++ .../deblurring-defocus_test_config.py | 106 ++++++++++ .../datasets/deblurring-motion_test_config.py | 106 ++++++++++ .../datasets/decompression_test_config.py | 75 +++++++ .../denoising-gaussian_color_test_config.py | 113 +++++++++++ .../denoising-gaussian_gray_test_config.py | 97 +++++++++ .../datasets/denoising-real_test_config.py | 69 +++++++ .../_base_/datasets/deraining_test_config.py | 125 ++++++++++++ .../grow_scale_imgs_ffhq_styleganv1.py | 51 +++++ .../configs/_base_/datasets/imagenet_128.py | 52 +++++ .../configs/_base_/datasets/imagenet_256.py | 52 +++++ .../_base_/datasets/liif_test_config.py | 95 +++++++++ .../configs/_base_/datasets/lsun_stylegan.py | 47 +++++ .../datasets/paired_imgs_256x256_crop.py | 108 ++++++++++ mmagic/configs/_base_/datasets/places.py | 50 +++++ .../_base_/datasets/sisr_x2_test_config.py | 98 +++++++++ .../_base_/datasets/sisr_x3_test_config.py | 97 +++++++++ .../_base_/datasets/sisr_x4_test_config.py | 97 +++++++++ .../_base_/datasets/tdan_test_config.py | 130 ++++++++++++ .../datasets/unconditional_imgs_128x128.py | 47 +++++ .../datasets/unconditional_imgs_64x64.py | 46 +++++ .../_base_/datasets/unpaired_imgs_256x256.py | 103 ++++++++++ .../configs/_base_/inpaint_default_runtime.py | 43 ++++ .../configs/_base_/matting_default_runtime.py | 43 ++++ mmagic/configs/_base_/models/base_cyclegan.py | 32 +++ .../configs/_base_/models/base_deepfillv1.py | 106 ++++++++++ .../configs/_base_/models/base_deepfillv2.py | 113 +++++++++++ mmagic/configs/_base_/models/base_edvr.py | 128 ++++++++++++ mmagic/configs/_base_/models/base_gl.py | 62 ++++++ mmagic/configs/_base_/models/base_glean.py | 58 ++++++ mmagic/configs/_base_/models/base_liif.py | 135 ++++++++++++ mmagic/configs/_base_/models/base_pconv.py | 60 ++++++ mmagic/configs/_base_/models/base_pix2pix.py | 32 +++ .../configs/_base_/models/base_styleganv1.py | 11 + .../configs/_base_/models/base_styleganv2.py | 38 ++++ mmagic/configs/_base_/models/base_tof.py | 120 +++++++++++ .../_base_/models/dcgan/base_dcgan_128x128.py | 16 ++ .../_base_/models/dcgan/base_dcgan_64x64.py | 16 ++ .../_base_/models/sagan/base_sagan_128x128.py | 26 +++ .../_base_/models/sagan/base_sagan_32x32.py | 28 +++ .../sngan_proj/base_sngan_proj_128x128.py | 14 ++ .../sngan_proj/base_sngan_proj_32x32.py | 14 ++ mmagic/configs/_base_/schedules/.gitkeep | 0 46 files changed, 3189 insertions(+) create mode 100644 mmagic/configs/_base_/datasets/basicvsr_test_config.py create mode 100644 mmagic/configs/_base_/datasets/celeba.py create mode 100644 mmagic/configs/_base_/datasets/cifar10_nopad.py create mode 100644 mmagic/configs/_base_/datasets/comp1k.py create mode 100644 mmagic/configs/_base_/datasets/deblurring-defocus_test_config.py create mode 100644 mmagic/configs/_base_/datasets/deblurring-motion_test_config.py create mode 100644 mmagic/configs/_base_/datasets/decompression_test_config.py create mode 100644 mmagic/configs/_base_/datasets/denoising-gaussian_color_test_config.py create mode 100644 mmagic/configs/_base_/datasets/denoising-gaussian_gray_test_config.py create mode 100644 mmagic/configs/_base_/datasets/denoising-real_test_config.py create mode 100644 mmagic/configs/_base_/datasets/deraining_test_config.py create mode 100644 mmagic/configs/_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py create mode 100644 mmagic/configs/_base_/datasets/imagenet_128.py create mode 100644 mmagic/configs/_base_/datasets/imagenet_256.py create mode 100644 mmagic/configs/_base_/datasets/liif_test_config.py create mode 100644 mmagic/configs/_base_/datasets/lsun_stylegan.py create mode 100644 mmagic/configs/_base_/datasets/paired_imgs_256x256_crop.py create mode 100644 mmagic/configs/_base_/datasets/places.py create mode 100644 mmagic/configs/_base_/datasets/sisr_x2_test_config.py create mode 100644 mmagic/configs/_base_/datasets/sisr_x3_test_config.py create mode 100644 mmagic/configs/_base_/datasets/sisr_x4_test_config.py create mode 100644 mmagic/configs/_base_/datasets/tdan_test_config.py create mode 100644 mmagic/configs/_base_/datasets/unconditional_imgs_128x128.py create mode 100644 mmagic/configs/_base_/datasets/unconditional_imgs_64x64.py create mode 100644 mmagic/configs/_base_/datasets/unpaired_imgs_256x256.py create mode 100644 mmagic/configs/_base_/inpaint_default_runtime.py create mode 100644 mmagic/configs/_base_/matting_default_runtime.py create mode 100644 mmagic/configs/_base_/models/base_cyclegan.py create mode 100644 mmagic/configs/_base_/models/base_deepfillv1.py create mode 100644 mmagic/configs/_base_/models/base_deepfillv2.py create mode 100644 mmagic/configs/_base_/models/base_edvr.py create mode 100644 mmagic/configs/_base_/models/base_gl.py create mode 100644 mmagic/configs/_base_/models/base_glean.py create mode 100644 mmagic/configs/_base_/models/base_liif.py create mode 100644 mmagic/configs/_base_/models/base_pconv.py create mode 100644 mmagic/configs/_base_/models/base_pix2pix.py create mode 100644 mmagic/configs/_base_/models/base_styleganv1.py create mode 100644 mmagic/configs/_base_/models/base_styleganv2.py create mode 100644 mmagic/configs/_base_/models/base_tof.py create mode 100644 mmagic/configs/_base_/models/dcgan/base_dcgan_128x128.py create mode 100644 mmagic/configs/_base_/models/dcgan/base_dcgan_64x64.py create mode 100644 mmagic/configs/_base_/models/sagan/base_sagan_128x128.py create mode 100644 mmagic/configs/_base_/models/sagan/base_sagan_32x32.py create mode 100644 mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_128x128.py create mode 100644 mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_32x32.py create mode 100644 mmagic/configs/_base_/schedules/.gitkeep diff --git a/mmagic/configs/_base_/datasets/basicvsr_test_config.py b/mmagic/configs/_base_/datasets/basicvsr_test_config.py new file mode 100644 index 000000000..01367f5c4 --- /dev/null +++ b/mmagic/configs/_base_/datasets/basicvsr_test_config.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicFramesDataset +from mmagic.datasets.transforms import (GenerateSegmentIndices, + LoadImageFromFile, MirrorSequence, + PackInputs) +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +# configs for REDS4 +reds_data_root = 'data/REDS' + +reds_pipeline = [ + dict(type=GenerateSegmentIndices, interval_list=[1]), + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), + dict(type=PackInputs) +] + +reds_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='reds_reds4', task_name='vsr'), + data_root=reds_data_root, + data_prefix=dict(img='train_sharp_bicubic/X4', gt='train_sharp'), + ann_file='meta_info_reds4_val.txt', + depth=1, + num_input_frames=100, + fixed_seq_len=100, + pipeline=reds_pipeline)) + +reds_evaluator = [ + dict(type=PSNR, prefix='REDS4-BIx4-RGB'), + dict(type=SSIM, prefix='REDS4-BIx4-RGB') +] + +# configs for vimeo90k-bd and vimeo90k-bi +vimeo_90k_data_root = 'data/vimeo90k' +vimeo_90k_file_list = [ + 'im1.png', 'im2.png', 'im3.png', 'im4.png', 'im5.png', 'im6.png', 'im7.png' +] + +vimeo_90k_pipeline = [ + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), + dict(type=MirrorSequence, keys=['img']), + dict(type=PackInputs) +] + +vimeo_90k_bd_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'), + data_root=vimeo_90k_data_root, + data_prefix=dict(img='BDx4', gt='GT'), + ann_file='meta_info_Vimeo90K_test_GT.txt', + depth=2, + num_input_frames=7, + fixed_seq_len=7, + load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']), + pipeline=vimeo_90k_pipeline)) + +vimeo_90k_bi_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'), + data_root=vimeo_90k_data_root, + data_prefix=dict(img='BIx4', gt='GT'), + ann_file='meta_info_Vimeo90K_test_GT.txt', + depth=2, + num_input_frames=7, + fixed_seq_len=7, + load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']), + pipeline=vimeo_90k_pipeline)) + +vimeo_90k_bd_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'), + dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'), +] + +vimeo_90k_bi_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'), + dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'), +] + +# config for UDM10 (BDx4) +udm10_data_root = 'data/UDM10' + +udm10_pipeline = [ + dict( + type=GenerateSegmentIndices, + interval_list=[1], + filename_tmpl='{:04d}.png'), + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), + dict(type=PackInputs) +] + +udm10_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='udm10', task_name='vsr'), + data_root=udm10_data_root, + data_prefix=dict(img='BDx4', gt='GT'), + pipeline=udm10_pipeline)) + +udm10_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='UDM10-BDx4-Y'), + dict(type=SSIM, convert_to='Y', prefix='UDM10-BDx4-Y') +] + +# config for vid4 +vid4_data_root = 'data/Vid4' + +vid4_pipeline = [ + dict(type=GenerateSegmentIndices, interval_list=[1]), + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), + dict(type=PackInputs) +] +vid4_bd_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='vid4', task_name='vsr'), + data_root=vid4_data_root, + data_prefix=dict(img='BDx4', gt='GT'), + ann_file='meta_info_Vid4_GT.txt', + depth=1, + pipeline=vid4_pipeline)) + +vid4_bi_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='vid4', task_name='vsr'), + data_root=vid4_data_root, + data_prefix=dict(img='BIx4', gt='GT'), + ann_file='meta_info_Vid4_GT.txt', + depth=1, + pipeline=vid4_pipeline)) + +vid4_bd_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='VID4-BDx4-Y'), + dict(type=SSIM, convert_to='Y', prefix='VID4-BDx4-Y'), +] +vid4_bi_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='VID4-BIx4-Y'), + dict(type=SSIM, convert_to='Y', prefix='VID4-BIx4-Y'), +] + +# config for test +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + reds_dataloader, + vimeo_90k_bd_dataloader, + vimeo_90k_bi_dataloader, + udm10_dataloader, + vid4_bd_dataloader, + vid4_bi_dataloader, +] +test_evaluator = [ + reds_evaluator, + vimeo_90k_bd_evaluator, + vimeo_90k_bi_evaluator, + udm10_evaluator, + vid4_bd_evaluator, + vid4_bi_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/celeba.py b/mmagic/configs/_base_/datasets/celeba.py new file mode 100644 index 000000000..1980be75f --- /dev/null +++ b/mmagic/configs/_base_/datasets/celeba.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.evaluation import MAE, PSNR, SSIM + +# Base config for CelebA-HQ dataset + +# dataset settings +dataset_type = 'BasicImageDataset' +data_root = 'data/CelebA-HQ' + +train_dataloader = dict( + num_workers=4, + persistent_workers=False, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt=''), + ann_file='train_celeba_img_list.txt', + test_mode=False, + )) + +val_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt=''), + ann_file='val_celeba_img_list.txt', + test_mode=True, + )) + +test_dataloader = val_dataloader + +val_evaluator = [ + dict(type=MAE, mask_key='mask', scaling=100), + # By default, compute with pixel value from 0-1 + # scale=2 to align with 1.0 + # scale=100 seems to align with readme + dict(type=PSNR), + dict(type=SSIM), +] + +test_evaluator = val_evaluator diff --git a/mmagic/configs/_base_/datasets/cifar10_nopad.py b/mmagic/configs/_base_/datasets/cifar10_nopad.py new file mode 100644 index 000000000..696802183 --- /dev/null +++ b/mmagic/configs/_base_/datasets/cifar10_nopad.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.datasets import CIFAR10 +from mmagic.datasets.transforms import Flip, PackInputs + +cifar_pipeline = [ + dict(type=Flip, keys=['gt'], flip_ratio=0.5, direction='horizontal'), + dict(type=PackInputs) +] +cifar_dataset = dict( + type=CIFAR10, + data_root='./data', + data_prefix='cifar10', + test_mode=False, + pipeline=cifar_pipeline) + +# test dataset do not use flip +cifar_pipeline_test = [dict(type=PackInputs)] +cifar_dataset_test = dict( + type=CIFAR10, + data_root='./data', + data_prefix='cifar10', + test_mode=False, + pipeline=cifar_pipeline_test) + +train_dataloader = dict( + num_workers=2, + dataset=cifar_dataset, + sampler=dict(type=InfiniteSampler, shuffle=True), + persistent_workers=True) + +val_dataloader = dict( + batch_size=32, + num_workers=2, + dataset=cifar_dataset_test, + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = dict( + batch_size=32, + num_workers=2, + dataset=cifar_dataset_test, + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) diff --git a/mmagic/configs/_base_/datasets/comp1k.py b/mmagic/configs/_base_/datasets/comp1k.py new file mode 100644 index 000000000..db4028b6d --- /dev/null +++ b/mmagic/configs/_base_/datasets/comp1k.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.evaluation import SAD, ConnectivityError, GradientError, MattingMSE + +# Base config for Composition-1K dataset + +# dataset settings +dataset_type = 'AdobeComp1kDataset' +data_root = 'data/adobe_composition-1k' + +train_dataloader = dict( + num_workers=4, + persistent_workers=False, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='training_list.json', + test_mode=False, + )) + +val_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='test_list.json', + test_mode=True, + )) + +test_dataloader = val_dataloader + +# TODO: matting +val_evaluator = [ + dict(type=SAD), + dict(type=MattingMSE), + dict(type=GradientError), + dict(type=ConnectivityError), +] + +test_evaluator = val_evaluator diff --git a/mmagic/configs/_base_/datasets/deblurring-defocus_test_config.py b/mmagic/configs/_base_/datasets/deblurring-defocus_test_config.py new file mode 100644 index 000000000..1addf30d3 --- /dev/null +++ b/mmagic/configs/_base_/datasets/deblurring-defocus_test_config.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import MAE, PSNR, SSIM + +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='imgL', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='imgR', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=PackInputs) +] + +dpdd_data_root = 'data/DPDD' + +dpdd_indoor_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='DPDD-Indoor', task_name='deblurring'), + data_root=dpdd_data_root, + data_prefix=dict( + img='inputC', imgL='inputL', imgR='inputR', gt='target'), + ann_file='indoor_labels.txt', + pipeline=test_pipeline)) +dpdd_indoor_evaluator = [ + dict(type=MAE, prefix='DPDD-Indoor'), + dict(type=PSNR, prefix='DPDD-Indoor'), + dict(type=SSIM, prefix='DPDD-Indoor'), +] + +dpdd_outdoor_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='DPDD-Outdoor', task_name='deblurring'), + data_root=dpdd_data_root, + data_prefix=dict( + img='inputC', imgL='inputL', imgR='inputR', gt='target'), + ann_file='outdoor_labels.txt', + pipeline=test_pipeline)) +dpdd_outdoor_evaluator = [ + dict(type=MAE, prefix='DPDD-Outdoor'), + dict(type=PSNR, prefix='DPDD-Outdoor'), + dict(type=SSIM, prefix='DPDD-Outdoor'), +] + +dpdd_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='DPDD-Combined', task_name='deblurring'), + data_root=dpdd_data_root, + data_prefix=dict( + img='inputC', imgL='inputL', imgR='inputR', gt='target'), + pipeline=test_pipeline)) +dpdd_evaluator = [ + dict(type=MAE, prefix='DPDD-Combined'), + dict(type=PSNR, prefix='DPDD-Combined'), + dict(type=SSIM, prefix='DPDD-Combined'), +] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + dpdd_indoor_dataloader, + dpdd_outdoor_dataloader, + dpdd_dataloader, +] +test_evaluator = [ + dpdd_indoor_evaluator, + dpdd_outdoor_evaluator, + dpdd_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/deblurring-motion_test_config.py b/mmagic/configs/_base_/datasets/deblurring-motion_test_config.py new file mode 100644 index 000000000..49e4c7308 --- /dev/null +++ b/mmagic/configs/_base_/datasets/deblurring-motion_test_config.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=PackInputs) +] + +gopro_data_root = 'data/gopro/test' +gopro_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='GoPro', task_name='deblurring'), + data_root=gopro_data_root, + data_prefix=dict(img='blur', gt='sharp'), + pipeline=test_pipeline)) +gopro_evaluator = [ + dict(type=PSNR, prefix='GoPro'), + dict(type=SSIM, prefix='GoPro'), +] + +hide_data_root = 'data/HIDE' +hide_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='HIDE', task_name='deblurring'), + data_root=hide_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +hide_evaluator = [ + dict(type=PSNR, prefix='HIDE'), + dict(type=SSIM, prefix='HIDE'), +] + +realblurj_data_root = 'data/RealBlur_J' +realblurj_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='RealBlur_J', task_name='deblurring'), + data_root=realblurj_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +realblurj_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='RealBlurJ'), + dict(type=SSIM, convert_to='Y', prefix='RealBlurJ'), +] + +realblurr_data_root = 'data/RealBlur_R' +realblurr_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='RealBlur_R', task_name='deblurring'), + data_root=realblurr_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +realblurr_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='RealBlurR'), + dict(type=SSIM, convert_to='Y', prefix='RealBlurR'), +] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + gopro_dataloader, + hide_dataloader, + realblurj_dataloader, + realblurr_dataloader, +] +test_evaluator = [ + gopro_evaluator, + hide_evaluator, + realblurj_evaluator, + realblurr_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/decompression_test_config.py b/mmagic/configs/_base_/datasets/decompression_test_config.py new file mode 100644 index 000000000..edc1e84f3 --- /dev/null +++ b/mmagic/configs/_base_/datasets/decompression_test_config.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import (LoadImageFromFile, PackInputs, + RandomJPEGCompression) +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +quality = 10 +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=RandomJPEGCompression, + params=dict(quality=[quality, quality]), + bgr2rgb=True, + keys=['img']), + dict(type=PackInputs) +] + +classic5_data_root = 'data/Classic5' +classic5_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='classic5', task_name='CAR'), + data_root=classic5_data_root, + data_prefix=dict(img='', gt=''), + pipeline=test_pipeline)) +classic5_evaluator = [ + dict(type=PSNR, prefix='Classic5'), + dict(type=SSIM, prefix='Classic5'), +] + +live1_data_root = 'data/LIVE1' +live1_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='live1', task_name='CAR'), + data_root=live1_data_root, + data_prefix=dict(img='', gt=''), + pipeline=test_pipeline)) +live1_evaluator = [ + dict(type=PSNR, prefix='LIVE1'), + dict(type=SSIM, prefix='LIVE1'), +] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + classic5_dataloader, + live1_dataloader, +] +test_evaluator = [ + classic5_evaluator, + live1_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/denoising-gaussian_color_test_config.py b/mmagic/configs/_base_/datasets/denoising-gaussian_color_test_config.py new file mode 100644 index 000000000..c1ec86b06 --- /dev/null +++ b/mmagic/configs/_base_/datasets/denoising-gaussian_color_test_config.py @@ -0,0 +1,113 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import (LoadImageFromFile, PackInputs, + RandomNoise) +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +sigma = 15 +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=RandomNoise, + params=dict( + noise_type=['gaussian'], + noise_prob=[1], + gaussian_sigma=[sigma, sigma], + gaussian_gray_noise_prob=0), + keys=['img']), + dict(type=PackInputs) +] + +data_root = 'data/denoising_gaussian_test' +cbsd68_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='CBSD68', task_name='denoising'), + data_root=data_root, + data_prefix=dict(img='CBSD68', gt='CBSD68'), + pipeline=test_pipeline)) +cbsd68_evaluator = [ + dict(type=PSNR, prefix='CBSD68'), + dict(type=SSIM, prefix='CBSD68'), +] + +kodak24_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Kodak24', task_name='denoising'), + data_root=data_root, + data_prefix=dict(img='Kodak24', gt='Kodak24'), + pipeline=test_pipeline)) +kodak24_evaluator = [ + dict(type=PSNR, prefix='Kodak24'), + dict(type=SSIM, prefix='Kodak24'), +] + +mcmaster_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='McMaster', task_name='denoising'), + data_root=data_root, + data_prefix=dict(img='McMaster', gt='McMaster'), + pipeline=test_pipeline)) +mcmaster_evaluator = [ + dict(type=PSNR, prefix='McMaster'), + dict(type=SSIM, prefix='McMaster'), +] + +urban100_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Urban100', task_name='denoising'), + data_root=data_root, + data_prefix=dict(img='Urban100', gt='Urban100'), + pipeline=test_pipeline)) +urban100_evaluator = [ + dict(type=PSNR, prefix='Urban100'), + dict(type=SSIM, prefix='Urban100'), +] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + cbsd68_dataloader, + kodak24_dataloader, + mcmaster_dataloader, + urban100_dataloader, +] +test_evaluator = [ + cbsd68_evaluator, + kodak24_evaluator, + mcmaster_evaluator, + urban100_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/denoising-gaussian_gray_test_config.py b/mmagic/configs/_base_/datasets/denoising-gaussian_gray_test_config.py new file mode 100644 index 000000000..c51cf5fe6 --- /dev/null +++ b/mmagic/configs/_base_/datasets/denoising-gaussian_gray_test_config.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import (LoadImageFromFile, PackInputs, + RandomNoise) +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +sigma = 15 +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + to_y_channel=True, + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + to_y_channel=True, + imdecode_backend='cv2'), + dict( + type=RandomNoise, + params=dict( + noise_type=['gaussian'], + noise_prob=[1], + gaussian_sigma=[sigma, sigma], + gaussian_gray_noise_prob=1), + keys=['img']), + dict(type=PackInputs) +] + +data_root = 'data/denoising_gaussian_test' +set12_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Set12', task_name='denoising'), + data_root=data_root, + data_prefix=dict(img='Set12', gt='Set12'), + pipeline=test_pipeline)) +set12_evaluator = [ + dict(type=PSNR, prefix='Set12'), + dict(type=SSIM, prefix='Set12'), +] + +bsd68_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='BSD68', task_name='denoising'), + data_root=data_root, + data_prefix=dict(img='BSD68', gt='BSD68'), + pipeline=test_pipeline)) +bsd68_evaluator = [ + dict(type=PSNR, prefix='BSD68'), + dict(type=SSIM, prefix='BSD68'), +] + +urban100_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Urban100', task_name='denoising'), + data_root=data_root, + data_prefix=dict(img='Urban100', gt='Urban100'), + pipeline=test_pipeline)) +urban100_evaluator = [ + dict(type=PSNR, prefix='Urban100'), + dict(type=SSIM, prefix='Urban100'), +] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + set12_dataloader, + bsd68_dataloader, + urban100_dataloader, +] +test_evaluator = [ + set12_evaluator, + bsd68_evaluator, + urban100_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/denoising-real_test_config.py b/mmagic/configs/_base_/datasets/denoising-real_test_config.py new file mode 100644 index 000000000..27b04757d --- /dev/null +++ b/mmagic/configs/_base_/datasets/denoising-real_test_config.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=PackInputs) +] + +sidd_data_root = 'data/SIDD/val/' +sidd_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='SIDD', task_name='denoising'), + data_root=sidd_data_root, + data_prefix=dict(img='noisy', gt='gt'), + filename_tmpl=dict(gt='{}_GT', img='{}_NOISY'), + pipeline=test_pipeline)) +sidd_evaluator = [ + dict(type=PSNR, prefix='SIDD'), + dict(type=SSIM, prefix='SIDD'), +] + +dnd_data_root = 'data/DND' +dnd_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='DND', task_name='denoising'), + data_root=dnd_data_root, + data_prefix=dict(img='input', gt='groundtruth'), + pipeline=test_pipeline)) +dnd_evaluator = [ + dict(type=PSNR, prefix='DND'), + dict(type=SSIM, prefix='DND'), +] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + sidd_dataloader, + # dnd_dataloader, +] +test_evaluator = [ + sidd_evaluator, + # dnd_dataloader, +] diff --git a/mmagic/configs/_base_/datasets/deraining_test_config.py b/mmagic/configs/_base_/datasets/deraining_test_config.py new file mode 100644 index 000000000..d8b28a1ac --- /dev/null +++ b/mmagic/configs/_base_/datasets/deraining_test_config.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=PackInputs) +] + +rain100h_data_root = 'data/Rain100H' +rain100h_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Rain100H', task_name='deraining'), + data_root=rain100h_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +rain100h_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='Rain100H'), + dict(type=SSIM, convert_to='Y', prefix='Rain100H'), +] + +rain100l_data_root = 'data/Rain100L' +rain100l_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Rain100L', task_name='deraining'), + data_root=rain100l_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +rain100l_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='Rain100L'), + dict(type=SSIM, convert_to='Y', prefix='Rain100L'), +] + +test100_data_root = 'data/Test100' +test100_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Test100', task_name='deraining'), + data_root=test100_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +test100_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='Test100'), + dict(type=SSIM, convert_to='Y', prefix='Test100'), +] + +test1200_data_root = 'data/Test1200' +test1200_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Test1200', task_name='deraining'), + data_root=test1200_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +test1200_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='Test1200'), + dict(type=SSIM, convert_to='Y', prefix='Test1200'), +] + +test2800_data_root = 'data/Test2800' +test2800_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='Test2800', task_name='deraining'), + data_root=test2800_data_root, + data_prefix=dict(img='input', gt='target'), + pipeline=test_pipeline)) +test2800_evaluator = [ + dict(type=PSNR, convert_to='Y', prefix='Test2800'), + dict(type=SSIM, convert_to='Y', prefix='Test2800'), +] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + rain100h_dataloader, + rain100l_dataloader, + test100_dataloader, + test1200_dataloader, + test2800_dataloader, +] +test_evaluator = [ + rain100h_evaluator, + rain100l_evaluator, + test100_evaluator, + test1200_evaluator, + test2800_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py b/mmagic/configs/_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py new file mode 100644 index 000000000..6c8db9451 --- /dev/null +++ b/mmagic/configs/_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.datasets import BasicImageDataset, GrowScaleImgDataset +from mmagic.datasets.transforms import Flip, LoadImageFromFile, PackInputs + +dataset_type = 'GrowScaleImgDataset' + +pipeline = [ + dict(type=LoadImageFromFile, key='gt'), + dict(type=Flip, keys='gt', direction='horizontal'), + dict(type=PackInputs) +] + +train_dataloader = dict( + num_workers=4, + batch_size=64, + dataset=dict( + type=GrowScaleImgDataset, + data_roots={ + '1024': './data/ffhq/images', + '256': './data/ffhq/ffhq_imgs/ffhq_256', + }, + gpu_samples_base=4, + # note that this should be changed with total gpu number + gpu_samples_per_scale={ + '4': 64, + '8': 32, + '16': 16, + '32': 8, + '64': 4, + '128': 4, + '256': 4, + '512': 4, + '1024': 4 + }, + len_per_stage=300000, + pipeline=pipeline), + sampler=dict(type=InfiniteSampler, shuffle=True)) + +test_dataloader = dict( + num_workers=4, + batch_size=64, + dataset=dict( + type=BasicImageDataset, + data_prefix=dict(gt=''), + pipeline=pipeline, + data_root='./data/ffhq/images'), + sampler=dict(type=DefaultSampler, shuffle=False)) + +val_dataloader = test_dataloader diff --git a/mmagic/configs/_base_/datasets/imagenet_128.py b/mmagic/configs/_base_/datasets/imagenet_128.py new file mode 100644 index 000000000..aaa272015 --- /dev/null +++ b/mmagic/configs/_base_/datasets/imagenet_128.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets.transforms import (CenterCropLongEdge, Flip, + LoadImageFromFile, PackInputs, + RandomCropLongEdge, Resize) + +# dataset settings +dataset_type = 'ImageNet' + +# different from mmcls, we adopt the setting used in BigGAN. +# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing. +train_pipeline = [ + dict(type=LoadImageFromFile, key='gt'), + dict(type=RandomCropLongEdge, keys='gt'), + dict(type=Resize, scale=(128, 128), keys='gt', backend='pillow'), + dict(type=Flip, keys='gt', flip_ratio=0.5, direction='horizontal'), + dict(type=PackInputs) +] + +test_pipeline = [ + dict(type=LoadImageFromFile, key='gt'), + dict(type=CenterCropLongEdge, keys='gt'), + dict(type=Resize, scale=(128, 128), keys='gt', backend='pillow'), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=None, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='./data/imagenet/', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), + persistent_workers=True) + +val_dataloader = dict( + batch_size=64, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='./data/imagenet/', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = val_dataloader diff --git a/mmagic/configs/_base_/datasets/imagenet_256.py b/mmagic/configs/_base_/datasets/imagenet_256.py new file mode 100644 index 000000000..a443b69a4 --- /dev/null +++ b/mmagic/configs/_base_/datasets/imagenet_256.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets.transforms import (CenterCropLongEdge, Flip, + LoadImageFromFile, PackInputs, + RandomCropLongEdge, Resize) + +# dataset settings +dataset_type = 'ImageNet' + +# different from mmcls, we adopt the setting used in BigGAN. +# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing. +train_pipeline = [ + dict(type=LoadImageFromFile, key='img'), + dict(type=RandomCropLongEdge, keys=['img']), + dict(type=Resize, scale=(256, 256), keys=['img'], backend='pillow'), + dict(type=Flip, keys=['img'], flip_ratio=0.5, direction='horizontal'), + dict(type=PackInputs) +] + +test_pipeline = [ + dict(type=LoadImageFromFile, key='img'), + dict(type=CenterCropLongEdge, keys=['img']), + dict(type=Resize, scale=(256, 256), backend='pillow'), + dict(type=PackInputs) +] + +train_dataloader = dict( + batch_size=None, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='./data/imagenet/', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), + persistent_workers=True) + +val_dataloader = dict( + batch_size=None, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='./data/imagenet/', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = val_dataloader diff --git a/mmagic/configs/_base_/datasets/liif_test_config.py b/mmagic/configs/_base_/datasets/liif_test_config.py new file mode 100644 index 000000000..343ab634b --- /dev/null +++ b/mmagic/configs/_base_/datasets/liif_test_config.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import (GenerateCoordinateAndCell, + LoadImageFromFile, PackInputs, + RandomDownSampling) +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM + +scale_test_list = [2, 3, 4, 6, 18, 30] + +test_pipelines = [[ + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=RandomDownSampling, scale_min=scale_test, scale_max=scale_test), + dict(type=GenerateCoordinateAndCell, scale=scale_test, reshape_gt=False), + dict(type=PackInputs) +] for scale_test in scale_test_list] + +# test config for Set5 +set5_dataloaders = [ + dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set5', task_name='sisr'), + data_root='data/Set5', + data_prefix=dict(img='LRbicx4', gt='GTmod12'), + pipeline=test_pipeline)) for test_pipeline in test_pipelines +] +set5_evaluators = [[ + dict(type=PSNR, crop_border=scale, prefix=f'Set5x{scale}'), + dict(type=SSIM, crop_border=scale, prefix=f'Set5x{scale}'), +] for scale in scale_test_list] + +# test config for Set14 +set14_dataloaders = [ + dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set14', task_name='sisr'), + data_root='data/Set14', + data_prefix=dict(img='LRbicx4', gt='GTmod12'), + pipeline=test_pipeline)) for test_pipeline in test_pipelines +] +set14_evaluators = [[ + dict(type=PSNR, crop_border=scale, prefix=f'Set14x{scale}'), + dict(type=SSIM, crop_border=scale, prefix=f'Set14x{scale}'), +] for scale in scale_test_list] + +# test config for DIV2K +div2k_dataloaders = [ + dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + ann_file='meta_info_DIV2K100sub_GT.txt', + metainfo=dict(dataset_type='div2k', task_name='sisr'), + data_root='data/DIV2K', + data_prefix=dict( + img='DIV2K_train_LR_bicubic/X4_sub', gt='DIV2K_train_HR_sub'), + pipeline=test_pipeline)) for test_pipeline in test_pipelines +] +div2k_evaluators = [[ + dict(type=PSNR, crop_border=scale, prefix=f'DIV2Kx{scale}'), + dict(type=SSIM, crop_border=scale, prefix=f'DIV2Kx{scale}'), +] for scale in scale_test_list] + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + *set5_dataloaders, + *set14_dataloaders, + *div2k_dataloaders, +] +test_evaluator = [ + *set5_evaluators, + *set14_evaluators, + *div2k_evaluators, +] diff --git a/mmagic/configs/_base_/datasets/lsun_stylegan.py b/mmagic/configs/_base_/datasets/lsun_stylegan.py new file mode 100644 index 000000000..1b66a8689 --- /dev/null +++ b/mmagic/configs/_base_/datasets/lsun_stylegan.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs + +dataset_type = 'BasicImageDataset' + +train_pipeline = [ + dict(type=LoadImageFromFile, key='gt'), + dict(type=PackInputs) +] + +val_pipeline = [dict(type=LoadImageFromFile, key='gt'), dict(type=PackInputs)] + +# `batch_size` and `data_root` need to be set. +train_dataloader = dict( + batch_size=4, + num_workers=8, + persistent_workers=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=4, + num_workers=8, + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=val_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = dict( + batch_size=4, + num_workers=8, + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=val_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) diff --git a/mmagic/configs/_base_/datasets/paired_imgs_256x256_crop.py b/mmagic/configs/_base_/datasets/paired_imgs_256x256_crop.py new file mode 100644 index 000000000..c814402fd --- /dev/null +++ b/mmagic/configs/_base_/datasets/paired_imgs_256x256_crop.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.datasets.transforms import (FixedCrop, Flip, + LoadPairedImageFromFile, PackInputs, + Resize) + +dataset_type = 'PairedImageDataset' +# domain_a = None # set by user +# domain_b = None # set by user + +train_pipeline = [ + dict( + type=LoadPairedImageFromFile, + key='pair', + domain_a='A', + domain_b='B', + color_type='color'), + dict( + type=Resize, + keys=['img_A', 'img_B'], + scale=(286, 286), + interpolation='bicubic'), + dict(type=FixedCrop, keys=['img_A', 'img_B'], crop_size=(256, 256)), + dict(type=Flip, keys=['img_A', 'img_B'], direction='horizontal'), + # NOTE: users should implement their own keyMapper and Pack operation + # dict( + # type='KeyMapper', + # mapping={ + # f'img_{domain_a}': 'img_A', + # f'img_{domain_b}': 'img_B' + # }, + # remapping={ + # f'img_{domain_a}': f'img_{domain_a}', + # f'img_{domain_b}': f'img_{domain_b}' + # }), + # dict( + # type=PackInputs, + # keys=[f'img_{domain_a}', f'img_{domain_b}'], + # data_keys=[f'img_{domain_a}', f'img_{domain_b}']) +] + +test_pipeline = [ + dict( + type=LoadPairedImageFromFile, + key='pair', + domain_a='A', + domain_b='B', + color_type='color'), + dict( + type='TransformBroadcaster', + mapping={'img': ['img_A', 'img_B']}, + auto_remap=True, + share_random_params=True, + transforms=[ + dict( + type=Resize, + scale=(256, 256), + keys='img', + interpolation='bicubic') + ]), + # NOTE: users should implement their own keyMapper and Pack operation + # dict( + # type='KeyMapper', + # mapping={ + # f'img_{domain_a}': 'img_A', + # f'img_{domain_b}': 'img_B' + # }, + # remapping={ + # f'img_{domain_a}': f'img_{domain_a}', + # f'img_{domain_b}': f'img_{domain_b}' + # }), + # dict( + # type=PackInputs, + # keys=[f'img_{domain_a}', f'img_{domain_b}'], + # data_keys=[f'img_{domain_a}', f'img_{domain_b}']) +] + +# `batch_size` and `data_root` need to be set. +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_root=None, # set by user + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root=None, # set by user + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = dict( + batch_size=4, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root=None, # set by user + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) diff --git a/mmagic/configs/_base_/datasets/places.py b/mmagic/configs/_base_/datasets/places.py new file mode 100644 index 000000000..a19822b4a --- /dev/null +++ b/mmagic/configs/_base_/datasets/places.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.evaluation import MAE, PSNR, SSIM + +# Base config for places365 dataset + +# dataset settings +dataset_type = 'BasicImageDataset' +data_root = 'data/Places' + +train_dataloader = dict( + num_workers=4, + persistent_workers=False, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt='data_large'), + ann_file='meta/places365_train_challenge.txt', + # Note that Places365-standard (1.8M images) and + # Place365-challenge (8M images) use different image lists. + test_mode=False, + )) + +val_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt='val_large'), + ann_file='meta/places365_val.txt', + test_mode=True, + )) + +test_dataloader = val_dataloader + +val_evaluator = [ + dict(type=MAE, mask_key='mask', scaling=100), + # By default, compute with pixel value from 0-1 + # scale=2 to align with 1.0 + # scale=100 seems to align with readme + dict(type=PSNR), + dict(type=SSIM), +] + +test_evaluator = val_evaluator diff --git a/mmagic/configs/_base_/datasets/sisr_x2_test_config.py b/mmagic/configs/_base_/datasets/sisr_x2_test_config.py new file mode 100644 index 000000000..1cf9d08c3 --- /dev/null +++ b/mmagic/configs/_base_/datasets/sisr_x2_test_config.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM, Evaluator + +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=PackInputs) +] + +# test config for Set5 +set5_data_root = 'data/Set5' +set5_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set5', task_name='sisr'), + data_root=set5_data_root, + data_prefix=dict(img='LRbicx2', gt='GTmod12'), + pipeline=test_pipeline)) +set5_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=2, prefix='Set5'), + dict(type=SSIM, crop_border=2, prefix='Set5'), + ]) + +set14_data_root = 'data/Set14' +set14_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set14', task_name='sisr'), + data_root=set14_data_root, + data_prefix=dict(img='LRbicx2', gt='GTmod12'), + pipeline=test_pipeline)) +set14_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=2, prefix='Set14'), + dict(type=SSIM, crop_border=2, prefix='Set14'), + ]) + +# test config for DIV2K +div2k_data_root = 'data/DIV2K' +div2k_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + # ann_file='meta_info_DIV2K800sub_GT.txt', + ann_file='meta_info_DIV2K100sub_GT.txt', + metainfo=dict(dataset_type='div2k', task_name='sisr'), + data_root=div2k_data_root, + data_prefix=dict( + img='DIV2K_train_LR_bicubic/X2_sub', gt='DIV2K_train_HR_sub'), + pipeline=test_pipeline)) +div2k_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=2, prefix='DIV2K'), + dict(type=SSIM, crop_border=2, prefix='DIV2K'), + ]) + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + set5_dataloader, + set14_dataloader, + div2k_dataloader, +] +test_evaluator = [ + set5_evaluator, + set14_evaluator, + div2k_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/sisr_x3_test_config.py b/mmagic/configs/_base_/datasets/sisr_x3_test_config.py new file mode 100644 index 000000000..be85f35af --- /dev/null +++ b/mmagic/configs/_base_/datasets/sisr_x3_test_config.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM, Evaluator + +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=PackInputs) +] + +# test config for Set5 +set5_data_root = 'data/Set5' +set5_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set5', task_name='sisr'), + data_root=set5_data_root, + data_prefix=dict(img='LRbicx3', gt='GTmod12'), + pipeline=test_pipeline)) +set5_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=3, prefix='Set5'), + dict(type=SSIM, crop_border=3, prefix='Set5'), + ]) + +set14_data_root = 'data/Set14' +set14_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set14', task_name='sisr'), + data_root=set14_data_root, + data_prefix=dict(img='LRbicx3', gt='GTmod12'), + pipeline=test_pipeline)) +set14_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=3, prefix='Set14'), + dict(type=SSIM, crop_border=3, prefix='Set14'), + ]) + +# test config for DIV2K +div2k_data_root = 'data/DIV2K' +div2k_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + ann_file='meta_info_DIV2K100sub_GT.txt', + metainfo=dict(dataset_type='div2k', task_name='sisr'), + data_root=div2k_data_root, + data_prefix=dict( + img='DIV2K_train_LR_bicubic/X3_sub', gt='DIV2K_train_HR_sub'), + pipeline=test_pipeline)) +div2k_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=3, prefix='DIV2K'), + dict(type=SSIM, crop_border=3, prefix='DIV2K'), + ]) + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + set5_dataloader, + set14_dataloader, + div2k_dataloader, +] +test_evaluator = [ + set5_evaluator, + set14_evaluator, + div2k_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/sisr_x4_test_config.py b/mmagic/configs/_base_/datasets/sisr_x4_test_config.py new file mode 100644 index 000000000..2d5af23d4 --- /dev/null +++ b/mmagic/configs/_base_/datasets/sisr_x4_test_config.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicImageDataset +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM, Evaluator + +test_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=PackInputs) +] + +# test config for Set5 +set5_data_root = 'data/Set5' +set5_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set5', task_name='sisr'), + data_root=set5_data_root, + data_prefix=dict(img='LRbicx4', gt='GTmod12'), + pipeline=test_pipeline)) +set5_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=4, prefix='Set5'), + dict(type=SSIM, crop_border=4, prefix='Set5'), + ]) + +set14_data_root = 'data/Set14' +set14_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + metainfo=dict(dataset_type='set14', task_name='sisr'), + data_root=set14_data_root, + data_prefix=dict(img='LRbicx4', gt='GTmod12'), + pipeline=test_pipeline)) +set14_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=4, prefix='Set14'), + dict(type=SSIM, crop_border=4, prefix='Set14'), + ]) + +# test config for DIV2K +div2k_data_root = 'data/DIV2K' +div2k_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicImageDataset, + ann_file='meta_info_DIV2K100sub_GT.txt', + metainfo=dict(dataset_type='div2k', task_name='sisr'), + data_root=div2k_data_root, + data_prefix=dict( + img='DIV2K_train_LR_bicubic/X4_sub', gt='DIV2K_train_HR_sub'), + pipeline=test_pipeline)) +div2k_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=4, prefix='DIV2K'), + dict(type=SSIM, crop_border=4, prefix='DIV2K'), + ]) + +# test config +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + set5_dataloader, + set14_dataloader, + div2k_dataloader, +] +test_evaluator = [ + set5_evaluator, + set14_evaluator, + div2k_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/tdan_test_config.py b/mmagic/configs/_base_/datasets/tdan_test_config.py new file mode 100644 index 000000000..1794cc040 --- /dev/null +++ b/mmagic/configs/_base_/datasets/tdan_test_config.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler + +from mmagic.datasets import BasicFramesDataset +from mmagic.datasets.transforms import (GenerateFrameIndiceswithPadding, + GenerateSegmentIndices, + LoadImageFromFile, PackInputs) +from mmagic.engine.runner import MultiTestLoop +from mmagic.evaluation import PSNR, SSIM, Evaluator + +# configs for SPMCS-30 +SPMC_data_root = 'data/SPMCS' + +SPMC_pipeline = [ + dict(type=GenerateFrameIndiceswithPadding, padding='reflection'), + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), + dict(type=PackInputs) +] + +SPMC_bd_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='spmcs', task_name='vsr'), + data_root=SPMC_data_root, + data_prefix=dict(img='BDx4', gt='GT'), + ann_file='meta_info_SPMCS_GT.txt', + depth=2, + num_input_frames=5, + pipeline=SPMC_pipeline)) + +SPMC_bi_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='spmcs', task_name='vsr'), + data_root=SPMC_data_root, + data_prefix=dict(img='BIx4', gt='GT'), + ann_file='meta_info_SPMCS_GT.txt', + depth=2, + num_input_frames=5, + pipeline=SPMC_pipeline)) + +SPMC_bd_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), + dict(type=SSIM, crop_border=8, convert_to='Y', prefix='SPMCS-BDx4-Y'), + ]) +SPMC_bi_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, crop_border=8, convert_to='Y', prefix='SPMCS-BIx4-Y'), + dict(type=SSIM, crop_border=8, convert_to='Y', prefix='SPMCS-BIx4-Y'), + ]) + +# config for vid4 +vid4_data_root = 'data/Vid4' + +vid4_pipeline = [ + # dict(type=GenerateSegmentIndices, interval_list=[1]), + dict(type=GenerateFrameIndiceswithPadding, padding='reflection'), + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), + dict(type=PackInputs) +] +vid4_bd_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='vid4', task_name='vsr'), + data_root=vid4_data_root, + data_prefix=dict(img='BDx4', gt='GT'), + ann_file='meta_info_Vid4_GT.txt', + depth=2, + num_input_frames=5, + pipeline=vid4_pipeline)) + +vid4_bi_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='vid4', task_name='vsr'), + data_root=vid4_data_root, + data_prefix=dict(img='BIx4', gt='GT'), + ann_file='meta_info_Vid4_GT.txt', + depth=2, + num_input_frames=5, + pipeline=vid4_pipeline)) + +vid4_bd_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, convert_to='Y', prefix='VID4-BDx4-Y'), + dict(type=SSIM, convert_to='Y', prefix='VID4-BDx4-Y'), + ]) +vid4_bi_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=PSNR, convert_to='Y', prefix='VID4-BIx4-Y'), + dict(type=SSIM, convert_to='Y', prefix='VID4-BIx4-Y'), + ]) + +# config for test +test_cfg = dict(type=MultiTestLoop) +test_dataloader = [ + SPMC_bd_dataloader, + SPMC_bi_dataloader, + vid4_bd_dataloader, + vid4_bi_dataloader, +] +test_evaluator = [ + SPMC_bd_evaluator, + SPMC_bi_evaluator, + vid4_bd_evaluator, + vid4_bi_evaluator, +] diff --git a/mmagic/configs/_base_/datasets/unconditional_imgs_128x128.py b/mmagic/configs/_base_/datasets/unconditional_imgs_128x128.py new file mode 100644 index 000000000..23126d0e0 --- /dev/null +++ b/mmagic/configs/_base_/datasets/unconditional_imgs_128x128.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs, Resize + +# dataset_type = 'BasicImageDataset' +dataset_type = 'BasicImageDataset' + +train_pipeline = [ + dict(type=LoadImageFromFile, key='gt'), + dict(type=Resize, keys='gt', scale=(128, 128)), + dict(type=PackInputs) +] + +# `batch_size` and `data_root` need to be set. +train_dataloader = dict( + batch_size=None, + num_workers=4, + persistent_workers=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=None, + num_workers=4, + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = dict( + batch_size=None, + num_workers=4, + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) diff --git a/mmagic/configs/_base_/datasets/unconditional_imgs_64x64.py b/mmagic/configs/_base_/datasets/unconditional_imgs_64x64.py new file mode 100644 index 000000000..def398d1d --- /dev/null +++ b/mmagic/configs/_base_/datasets/unconditional_imgs_64x64.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs, Resize + +dataset_type = 'BasicImageDataset' + +train_pipeline = [ + dict(type=LoadImageFromFile, key='gt'), + dict(type=Resize, keys='gt', scale=(64, 64)), + dict(type=PackInputs) +] + +# `batch_size` and `data_root` need to be set. +train_dataloader = dict( + batch_size=None, + num_workers=4, + persistent_workers=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=None, + num_workers=4, + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = dict( + batch_size=None, + num_workers=4, + dataset=dict( + type=dataset_type, + data_prefix=dict(gt=''), + data_root=None, # set by user + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) diff --git a/mmagic/configs/_base_/datasets/unpaired_imgs_256x256.py b/mmagic/configs/_base_/datasets/unpaired_imgs_256x256.py new file mode 100644 index 000000000..38de740e8 --- /dev/null +++ b/mmagic/configs/_base_/datasets/unpaired_imgs_256x256.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler + +from mmagic.datasets.transforms import Crop, Flip, PackInputs, Resize + +dataset_type = 'UnpairedImageDataset' +domain_a = None # set by user +domain_b = None # set by user +train_pipeline = [ + dict(type='LoadImageFromFile', key='img_A', color_type='color'), + dict(type='LoadImageFromFile', key='img_B', color_type='color'), + dict( + type='TransformBroadcaster', + mapping={'img': ['img_A', 'img_B']}, + auto_remap=True, + share_random_params=True, + transforms=[ + dict(type=Resize, scale=(286, 286), interpolation='bicubic'), + dict( + type=Crop, + keys=['img'], + crop_size=(256, 256), + random_crop=True), + ]), + dict(type=Flip, keys=['img_A'], direction='horizontal'), + dict(type=Flip, keys=['img_B'], direction='horizontal'), + # NOTE: users should implement their own keyMapper and Pack operation + # dict( + # type='KeyMapper', + # mapping={ + # f'img_{domain_a}': 'img_A', + # f'img_{domain_b}': 'img_B' + # }, + # remapping={ + # f'img_{domain_a}': f'img_{domain_a}', + # f'img_{domain_b}': f'img_{domain_b}' + # }), + # dict( + # type=PackInputs, + # keys=[f'img_{domain_a}', f'img_{domain_b}'], + # data_keys=[f'img_{domain_a}', f'img_{domain_b}']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile', key='img_A', color_type='color'), + dict(type='LoadImageFromFile', key='img_B', color_type='color'), + dict( + type='TransformBroadcaster', + mapping={'img': ['img_A', 'img_B']}, + auto_remap=True, + share_random_params=True, + transforms=dict( + type=Resize, scale=(256, 256), interpolation='bicubic'), + ), + # NOTE: users should implement their own keyMapper and Pack operation + # dict( + # type='KeyMapper', + # mapping={ + # f'img_{domain_a}': 'img_A', + # f'img_{domain_b}': 'img_B' + # }, + # remapping={ + # f'img_{domain_a}': f'img_{domain_a}', + # f'img_{domain_b}': f'img_{domain_b}' + # }), + # dict( + # type=PackInputs, + # keys=[f'img_{domain_a}', f'img_{domain_b}'], + # data_keys=[f'img_{domain_a}', f'img_{domain_b}']) +] + +# `batch_size` and `data_root` need to be set. +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_root=None, # set by user + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root=None, # set by user + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) + +test_dataloader = dict( + batch_size=4, + num_workers=4, + dataset=dict( + type=dataset_type, + data_root=None, # set by user + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), + persistent_workers=True) diff --git a/mmagic/configs/_base_/inpaint_default_runtime.py b/mmagic/configs/_base_/inpaint_default_runtime.py new file mode 100644 index 000000000..6d776e5c1 --- /dev/null +++ b/mmagic/configs/_base_/inpaint_default_runtime.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.runner import LogProcessor +from mmengine.visualization import LocalVisBackend + +from mmagic.engine.hooks import BasicVisualizationHook +from mmagic.visualization import ConcatImageVisualizer + +default_scope = 'mmagic' +save_dir = './work_dirs' + +default_hooks = dict( + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=100), + param_scheduler=dict(type=ParamSchedulerHook), + checkpoint=dict( + type=CheckpointHook, interval=50000, by_epoch=False, out_dir=save_dir), + sampler_seed=dict(type=DistSamplerSeedHook), +) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +vis_backends = [dict(type=LocalVisBackend)] +visualizer = dict( + type=ConcatImageVisualizer, + vis_backends=vis_backends, + fn_key='gt_path', + img_keys=['gt_img', 'input', 'pred_img'], + bgr2rgb=True) +custom_hooks = [dict(type=BasicVisualizationHook, interval=1)] + +log_level = 'INFO' +log_processor = dict(type=LogProcessor, by_epoch=False) + +load_from = None +resume = False + +# TODO: support auto scaling lr diff --git a/mmagic/configs/_base_/matting_default_runtime.py b/mmagic/configs/_base_/matting_default_runtime.py new file mode 100644 index 000000000..07344cbc0 --- /dev/null +++ b/mmagic/configs/_base_/matting_default_runtime.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.runner import LogProcessor +from mmengine.visualization import LocalVisBackend + +from mmagic.engine.hooks import BasicVisualizationHook +from mmagic.visualization import ConcatImageVisualizer + +default_scope = 'mmagic' +save_dir = './work_dirs' + +default_hooks = dict( + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=100), + param_scheduler=dict(type=ParamSchedulerHook), + checkpoint=dict( + type=CheckpointHook, interval=10000, by_epoch=False, out_dir=save_dir), + sampler_seed=dict(type=DistSamplerSeedHook), +) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=4), + dist_cfg=dict(backend='nccl'), +) + +vis_backends = [dict(type=LocalVisBackend)] +visualizer = dict( + type=ConcatImageVisualizer, + vis_backends=vis_backends, + fn_key='trimap_path', + img_keys=['pred_alpha', 'trimap', 'gt_merged', 'gt_alpha'], + bgr2rgb=True) +custom_hooks = [dict(type=BasicVisualizationHook, interval=2000)] + +log_level = 'INFO' +log_processor = dict(type=LogProcessor, by_epoch=False) + +load_from = None +resume = False + +# TODO: support auto scaling lr diff --git a/mmagic/configs/_base_/models/base_cyclegan.py b/mmagic/configs/_base_/models/base_cyclegan.py new file mode 100644 index 000000000..fbaa8a0d2 --- /dev/null +++ b/mmagic/configs/_base_/models/base_cyclegan.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.archs import PatchDiscriminator +from mmagic.models.editors import CycleGAN +from mmagic.models.editors.cyclegan import ResnetGenerator + +_domain_a = None # set by user +_domain_b = None # set by user +model = dict( + type=CycleGAN, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict( + type=ResnetGenerator, + in_channels=3, + out_channels=3, + base_channels=64, + norm_cfg=dict(type='IN'), + use_dropout=False, + num_blocks=9, + padding_mode='reflect', + init_cfg=dict(type='normal', gain=0.02)), + discriminator=dict( + type=PatchDiscriminator, + in_channels=3, + base_channels=64, + num_conv=3, + norm_cfg=dict(type='IN'), + init_cfg=dict(type='normal', gain=0.02)), + default_domain=None, # set by user + reachable_domains=None, # set by user + related_domains=None # set by user +) diff --git a/mmagic/configs/_base_/models/base_deepfillv1.py b/mmagic/configs/_base_/models/base_deepfillv1.py new file mode 100644 index 000000000..7588597a4 --- /dev/null +++ b/mmagic/configs/_base_/models/base_deepfillv1.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.model import MMSeparateDistributedDataParallel +from mmengine.optim import OptimWrapper + +from mmagic.models import DataPreprocessor +from mmagic.models.archs import MultiLayerDiscriminator +from mmagic.models.editors import (ContextualAttentionNeck, DeepFillDecoder, + DeepFillEncoder, DeepFillEncoderDecoder, + DeepFillRefiner, DeepFillv1Discriminators, + DeepFillv1Inpaintor, GLDilationNeck, + GLEncoderDecoder) +from mmagic.models.losses import (DiscShiftLoss, GANLoss, GradientPenaltyLoss, + L1Loss) + +# DistributedDataParallel +model_wrapper_cfg = dict(type=MMSeparateDistributedDataParallel) + +model = dict( + type=DeepFillv1Inpaintor, + data_preprocessor=dict( + type=DataPreprocessor, + mean=[127.5], + std=[127.5], + ), + encdec=dict( + type=DeepFillEncoderDecoder, + stage1=dict( + type=GLEncoderDecoder, + encoder=dict(type=DeepFillEncoder, padding_mode='reflect'), + decoder=dict( + type=DeepFillDecoder, in_channels=128, padding_mode='reflect'), + dilation_neck=dict( + type=GLDilationNeck, + in_channels=128, + act_cfg=dict(type='ELU'), + padding_mode='reflect')), + stage2=dict( + type=DeepFillRefiner, + encoder_attention=dict( + type=DeepFillEncoder, + encoder_type='stage2_attention', + padding_mode='reflect'), + encoder_conv=dict( + type=DeepFillEncoder, + encoder_type='stage2_conv', + padding_mode='reflect'), + dilation_neck=dict( + type=GLDilationNeck, + in_channels=128, + act_cfg=dict(type='ELU'), + padding_mode='reflect'), + contextual_attention=dict( + type=ContextualAttentionNeck, + in_channels=128, + padding_mode='reflect'), + decoder=dict( + type=DeepFillDecoder, in_channels=256, + padding_mode='reflect'))), + disc=dict( + type=DeepFillv1Discriminators, + global_disc_cfg=dict( + type=MultiLayerDiscriminator, + in_channels=3, + max_channels=256, + fc_in_channels=256 * 16 * 16, + fc_out_channels=1, + num_convs=4, + norm_cfg=None, + act_cfg=dict(type='ELU'), + out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2)), + local_disc_cfg=dict( + type=MultiLayerDiscriminator, + in_channels=3, + max_channels=512, + fc_in_channels=512 * 8 * 8, + fc_out_channels=1, + num_convs=4, + norm_cfg=None, + act_cfg=dict(type='ELU'), + out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2))), + stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'), + stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'), + loss_gan=dict( + type=GANLoss, + gan_type='wgan', + loss_weight=0.0001, + ), + loss_l1_hole=dict( + type=L1Loss, + loss_weight=1.0, + ), + loss_l1_valid=dict( + type=L1Loss, + loss_weight=1.0, + ), + loss_gp=dict(type=GradientPenaltyLoss, loss_weight=10.), + loss_disc_shift=dict(type=DiscShiftLoss, loss_weight=0.001)) + +# optimizer +optim_wrapper = dict( + constructor='MultiOptimWrapperConstructor', + generator=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0001)), + disc=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0001))) + +# learning policy +# Fixed diff --git a/mmagic/configs/_base_/models/base_deepfillv2.py b/mmagic/configs/_base_/models/base_deepfillv2.py new file mode 100644 index 000000000..0ed601203 --- /dev/null +++ b/mmagic/configs/_base_/models/base_deepfillv2.py @@ -0,0 +1,113 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.model import MMSeparateDistributedDataParallel +from mmengine.optim import OptimWrapper + +from mmagic.models import DataPreprocessor +from mmagic.models.archs import MultiLayerDiscriminator +from mmagic.models.base_models import TwoStageInpaintor +from mmagic.models.editors import (ContextualAttentionNeck, DeepFillDecoder, + DeepFillEncoder, DeepFillEncoderDecoder, + DeepFillRefiner, GLDilationNeck, + GLEncoderDecoder) +from mmagic.models.losses import GANLoss, L1Loss + +# DistributedDataParallel +model_wrapper_cfg = dict(type=MMSeparateDistributedDataParallel) + +model = dict( + type=TwoStageInpaintor, + disc_input_with_mask=True, + data_preprocessor=dict( + type=DataPreprocessor, + mean=[127.5], + std=[127.5], + ), + encdec=dict( + type=DeepFillEncoderDecoder, + stage1=dict( + type=GLEncoderDecoder, + encoder=dict( + type=DeepFillEncoder, + conv_type='gated_conv', + channel_factor=0.75, + padding_mode='reflect'), + decoder=dict( + type=DeepFillDecoder, + conv_type='gated_conv', + in_channels=96, + channel_factor=0.75, + out_act_cfg=dict(type='Tanh'), + padding_mode='reflect'), + dilation_neck=dict( + type=GLDilationNeck, + in_channels=96, + conv_type='gated_conv', + act_cfg=dict(type='ELU'), + padding_mode='reflect')), + stage2=dict( + type=DeepFillRefiner, + encoder_attention=dict( + type=DeepFillEncoder, + encoder_type='stage2_attention', + conv_type='gated_conv', + channel_factor=0.75, + padding_mode='reflect'), + encoder_conv=dict( + type=DeepFillEncoder, + encoder_type='stage2_conv', + conv_type='gated_conv', + channel_factor=0.75, + padding_mode='reflect'), + dilation_neck=dict( + type=GLDilationNeck, + in_channels=96, + conv_type='gated_conv', + act_cfg=dict(type='ELU'), + padding_mode='reflect'), + contextual_attention=dict( + type=ContextualAttentionNeck, + in_channels=96, + conv_type='gated_conv', + padding_mode='reflect'), + decoder=dict( + type=DeepFillDecoder, + in_channels=192, + conv_type='gated_conv', + out_act_cfg=dict(type='Tanh'), + padding_mode='reflect'))), + disc=dict( + type=MultiLayerDiscriminator, + in_channels=4, + max_channels=256, + fc_in_channels=None, + num_convs=6, + norm_cfg=None, + act_cfg=dict(type='LeakyReLU', negative_slope=0.2), + out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2), + with_spectral_norm=True, + ), + stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'), + stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'), + loss_gan=dict( + type=GANLoss, + gan_type='hinge', + loss_weight=0.1, + ), + loss_l1_hole=dict( + type=L1Loss, + loss_weight=1.0, + ), + loss_l1_valid=dict( + type=L1Loss, + loss_weight=1.0, + ), +) + +# optimizer +optim_wrapper = dict( + constructor='MultiOptimWrapperConstructor', + generator=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0001)), + disc=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0001))) + +# learning policy +# Fixed diff --git a/mmagic/configs/_base_/models/base_edvr.py b/mmagic/configs/_base_/models/base_edvr.py new file mode 100644 index 000000000..0ae9a49ea --- /dev/null +++ b/mmagic/configs/_base_/models/base_edvr.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler +from mmengine.hooks import CheckpointHook +from mmengine.optim import OptimWrapper +from mmengine.runner import IterBasedTrainLoop + +from mmagic.datasets import BasicFramesDataset +from mmagic.datasets.transforms import (Flip, GenerateFrameIndices, + GenerateFrameIndiceswithPadding, + GenerateSegmentIndices, + LoadImageFromFile, PackInputs, + PairedRandomCrop, RandomTransposeHW, + SetValues, TemporalReverse) +from mmagic.engine.runner import MultiTestLoop, MultiValLoop +from mmagic.evaluation import PSNR, SSIM + +_base_ = '../default_runtime.py' + +scale = 4 + +train_pipeline = [ + dict(type=GenerateFrameIndices, interval_list=[1], frames_per_clip=99), + dict(type=TemporalReverse, keys='img_path', reverse_ratio=0), + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb'), + dict(type=SetValues, dictionary=dict(scale=scale)), + dict(type=PairedRandomCrop, gt_patch_size=256), + dict( + type=Flip, keys=['img', 'gt'], flip_ratio=0.5, direction='horizontal'), + dict(type=Flip, keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'), + dict(type=RandomTransposeHW, keys=['img', 'gt'], transpose_ratio=0.5), + dict(type=PackInputs) +] + +val_pipeline = [ + dict(type=GenerateFrameIndiceswithPadding, padding='reflection_circle'), + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb'), + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb'), + dict(type=PackInputs) +] + +demo_pipeline = [ + dict(type=GenerateSegmentIndices, interval_list=[1]), + dict( + type=LoadImageFromFile, + key='img', + color_type='color', + channel_order='rgb'), + dict(type=PackInputs) +] + +data_root = 'data/REDS' +save_dir = './work_dirs' + +train_dataloader = dict( + num_workers=8, + batch_size=8, + persistent_workers=False, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='reds_reds4', task_name='vsr'), + data_root=data_root, + data_prefix=dict(img='train_sharp_bicubic/X4', gt='train_sharp'), + ann_file='meta_info_reds4_train.txt', + depth=2, + num_input_frames=5, + num_output_frames=1, + pipeline=train_pipeline)) + +val_dataloader = dict( + num_workers=1, + batch_size=1, + persistent_workers=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=BasicFramesDataset, + metainfo=dict(dataset_type='reds_reds4', task_name='vsr'), + data_root=data_root, + data_prefix=dict(img='train_sharp_bicubic/X4', gt='train_sharp'), + ann_file='meta_info_reds4_val.txt', + depth=2, + num_input_frames=5, + num_output_frames=1, + pipeline=val_pipeline)) + +test_dataloader = val_dataloader + +val_evaluator = [ + dict(type=PSNR), + dict(type=SSIM), +] +test_evaluator = val_evaluator + +train_cfg = dict(type=IterBasedTrainLoop, max_iters=600_000, val_interval=5000) +val_cfg = dict(type=MultiValLoop) +test_cfg = dict(type=MultiTestLoop) + +# optimizer +optim_wrapper = dict( + constructor='DefaultOptimWrapperConstructor', + type=OptimWrapper, + optimizer=dict(type='Adam', lr=2e-4, betas=(0.9, 0.999)), +) + +default_hooks = dict( + checkpoint=dict( + type=CheckpointHook, + interval=5000, + save_optimizer=True, + out_dir=save_dir, + by_epoch=False)) diff --git a/mmagic/configs/_base_/models/base_gl.py b/mmagic/configs/_base_/models/base_gl.py new file mode 100644 index 000000000..30cb72db2 --- /dev/null +++ b/mmagic/configs/_base_/models/base_gl.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.model import MMSeparateDistributedDataParallel +from mmengine.optim import OptimWrapper + +from mmagic.models import DataPreprocessor +from mmagic.models.editors import (GLDecoder, GLDilationNeck, GLEncoder, + GLEncoderDecoder) +from mmagic.models.editors.global_local import GLDiscs, GLInpaintor +from mmagic.models.losses import GANLoss, L1Loss + +# DistributedDataParallel +model_wrapper_cfg = dict(type=MMSeparateDistributedDataParallel) + +model = dict( + type=GLInpaintor, + data_preprocessor=dict( + type=DataPreprocessor, + mean=[127.5], + std=[127.5], + ), + encdec=dict( + type=GLEncoderDecoder, + encoder=dict(type=GLEncoder, norm_cfg=dict(type='SyncBN')), + decoder=dict(type=GLDecoder, norm_cfg=dict(type='SyncBN')), + dilation_neck=dict(type=GLDilationNeck, norm_cfg=dict(type='SyncBN'))), + disc=dict( + type=GLDiscs, + global_disc_cfg=dict( + in_channels=3, + max_channels=512, + fc_in_channels=512 * 4 * 4, + fc_out_channels=1024, + num_convs=6, + norm_cfg=dict(type='SyncBN'), + ), + local_disc_cfg=dict( + in_channels=3, + max_channels=512, + fc_in_channels=512 * 4 * 4, + fc_out_channels=1024, + num_convs=5, + norm_cfg=dict(type='SyncBN'), + ), + ), + loss_gan=dict( + type=GANLoss, + gan_type='vanilla', + loss_weight=0.001, + ), + loss_l1_hole=dict( + type=L1Loss, + loss_weight=1.0, + )) + +# optimizer +optim_wrapper = dict( + constructor='MultiOptimWrapperConstructor', + generator=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0004)), + disc=dict(type=OptimWrapper, optimizer=dict(type='Adam', lr=0.0004))) + +# learning policy +# Fixed diff --git a/mmagic/configs/_base_/models/base_glean.py b/mmagic/configs/_base_/models/base_glean.py new file mode 100644 index 000000000..5a2faed3f --- /dev/null +++ b/mmagic/configs/_base_/models/base_glean.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.model import MMSeparateDistributedDataParallel +from mmengine.optim import CosineAnnealingLR, OptimWrapper +from mmengine.runner import IterBasedTrainLoop + +from mmagic.engine.runner import MultiTestLoop, MultiValLoop +from mmagic.evaluation import MAE, PSNR, SSIM + +_base_ = '../default_runtime.py' + +# DistributedDataParallel +model_wrapper_cfg = dict( + type=MMSeparateDistributedDataParallel, find_unused_parameters=True) + +save_dir = './work_dirs' + +val_evaluator = [ + dict(type=MAE), + dict(type=PSNR), + dict(type=SSIM), +] +test_evaluator = val_evaluator + +train_cfg = dict(type=IterBasedTrainLoop, max_iters=300_000, val_interval=5000) +val_cfg = dict(type=MultiValLoop) +test_cfg = dict(type=MultiTestLoop) + +# optimizer +optim_wrapper = dict( + constructor='MultiOptimWrapperConstructor', + generator=dict( + type=OptimWrapper, + optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99))), + discriminator=dict( + type=OptimWrapper, + optimizer=dict(type='Adam', lr=1e-4, betas=(0.9, 0.99))), +) + +# learning policy +param_scheduler = dict( + type=CosineAnnealingLR, by_epoch=False, T_max=600_000, eta_min=1e-7) + +default_hooks = dict( + checkpoint=dict( + type=CheckpointHook, + interval=5000, + save_optimizer=True, + by_epoch=False, + out_dir=save_dir, + save_best=['MAE', 'PSNR', 'SSIM'], + rule=['less', 'greater', 'greater']), + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=100), + param_scheduler=dict(type=ParamSchedulerHook), + sampler_seed=dict(type=DistSamplerSeedHook), +) diff --git a/mmagic/configs/_base_/models/base_liif.py b/mmagic/configs/_base_/models/base_liif.py new file mode 100644 index 000000000..78cd4c05b --- /dev/null +++ b/mmagic/configs/_base_/models/base_liif.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import MultiStepLR, OptimWrapper +from mmengine.runner import IterBasedTrainLoop + +from mmagic.datasets.transforms import (Flip, GenerateCoordinateAndCell, + LoadImageFromFile, PackInputs, + RandomDownSampling, RandomTransposeHW) +from mmagic.engine.runner import MultiValLoop +from mmagic.evaluation import MAE, PSNR, SSIM, Evaluator + +_base_ = '../default_runtime.py' +work_dir = './work_dirs/liif' +save_dir = './work_dirs' + +scale_min, scale_max = 1, 4 +scale_test = 4 + +train_pipeline = [ + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict( + type=RandomDownSampling, + scale_min=scale_min, + scale_max=scale_max, + patch_size=48), + dict( + type=Flip, keys=['img', 'gt'], flip_ratio=0.5, direction='horizontal'), + dict(type=Flip, keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'), + dict(type=RandomTransposeHW, keys=['img', 'gt'], transpose_ratio=0.5), + dict(type=GenerateCoordinateAndCell, sample_quantity=2304), + dict(type=PackInputs) +] +val_pipeline = [ + dict( + type=LoadImageFromFile, + key='gt', + color_type='color', + channel_order='rgb', + imdecode_backend='cv2'), + dict(type=RandomDownSampling, scale_min=scale_max, scale_max=scale_max), + dict(type=GenerateCoordinateAndCell, reshape_gt=False), + dict(type=PackInputs) +] +# test_pipeline = [ +# dict( +# type=LoadImageFromFile, +# key='gt', +# color_type='color', +# channel_order='rgb', +# imdecode_backend='cv2'), +# dict( +# type=LoadImageFromFile, +# key='img', +# color_type='color', +# channel_order='rgb', +# imdecode_backend='cv2'), +# dict(type=GenerateCoordinateAndCell, scale=scale_test, +# reshape_gt=False), +# dict(type=PackInputs) +# ] + +# dataset settings +dataset_type = 'BasicImageDataset' +data_root = 'data' + +train_dataloader = dict( + num_workers=8, + batch_size=16, + persistent_workers=False, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='meta_info_DIV2K800sub_GT.txt', + metainfo=dict(dataset_type='div2k', task_name='sisr'), + data_root=data_root + '/DIV2K', + data_prefix=dict(gt='DIV2K_train_HR_sub'), + pipeline=train_pipeline)) + +val_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + metainfo=dict(dataset_type='set5', task_name='sisr'), + data_root=data_root + '/Set5', + data_prefix=dict(img='LRbicx4', gt='GTmod12'), + pipeline=val_pipeline)) + +val_evaluator = dict( + type=Evaluator, + metrics=[ + dict(type=MAE), + dict(type=PSNR, crop_border=scale_max), + dict(type=SSIM, crop_border=scale_max), + ]) + +train_cfg = dict( + type=IterBasedTrainLoop, max_iters=1_000_000, val_interval=3000) +val_cfg = dict(type=MultiValLoop) + +# optimizer +optim_wrapper = dict( + constructor='DefaultOptimWrapperConstructor', + type=OptimWrapper, + optimizer=dict(type='Adam', lr=1e-4)) + +# learning policy +param_scheduler = dict( + type=MultiStepLR, + by_epoch=False, + milestones=[200_000, 400_000, 600_000, 800_000], + gamma=0.5) + +default_hooks = dict( + checkpoint=dict( + type=CheckpointHook, + interval=3000, + save_optimizer=True, + by_epoch=False, + out_dir=save_dir, + ), + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=100), + param_scheduler=dict(type=ParamSchedulerHook), + sampler_seed=dict(type=DistSamplerSeedHook), +) diff --git a/mmagic/configs/_base_/models/base_pconv.py b/mmagic/configs/_base_/models/base_pconv.py new file mode 100644 index 000000000..6b6dc571e --- /dev/null +++ b/mmagic/configs/_base_/models/base_pconv.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.model import MMSeparateDistributedDataParallel +from mmengine.optim import OptimWrapper + +from mmagic.models import DataPreprocessor +from mmagic.models.editors import (PConvDecoder, PConvEncoder, + PConvEncoderDecoder, PConvInpaintor) +from mmagic.models.losses import L1Loss, MaskedTVLoss, PerceptualLoss + +# DistributedDataParallel +model_wrapper_cfg = dict(type=MMSeparateDistributedDataParallel) + +model = dict( + type=PConvInpaintor, + data_preprocessor=dict( + type=DataPreprocessor, + mean=[127.5], + std=[127.5], + ), + encdec=dict( + type=PConvEncoderDecoder, + encoder=dict( + type=PConvEncoder, + norm_cfg=dict(type='SyncBN', requires_grad=False), + norm_eval=True), + decoder=dict(type=PConvDecoder, norm_cfg=dict(type='SyncBN'))), + disc=None, + loss_composed_percep=dict( + type=PerceptualLoss, + vgg_type='vgg16', + layer_weights={ + '4': 1., + '9': 1., + '16': 1., + }, + perceptual_weight=0.05, + style_weight=120, + pretrained=('torchvision://vgg16')), + loss_out_percep=True, + loss_l1_hole=dict( + type=L1Loss, + loss_weight=6., + ), + loss_l1_valid=dict( + type=L1Loss, + loss_weight=1., + ), + loss_tv=dict( + type=MaskedTVLoss, + loss_weight=0.1, + )) + +# optimizer +optim_wrapper = dict( + constructor='DefaultOptimWrapperConstructor', + type=OptimWrapper, + optimizer=dict(type='Adam', lr=0.00005)) + +# learning policy +# Fixed diff --git a/mmagic/configs/_base_/models/base_pix2pix.py b/mmagic/configs/_base_/models/base_pix2pix.py new file mode 100644 index 000000000..1fadaab45 --- /dev/null +++ b/mmagic/configs/_base_/models/base_pix2pix.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.archs import PatchDiscriminator +from mmagic.models.editors import Pix2Pix +from mmagic.models.editors.pix2pix import UnetGenerator + +source_domain = None # set by user +target_domain = None # set by user +# model settings +model = dict( + type=Pix2Pix, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict( + type=UnetGenerator, + in_channels=3, + out_channels=3, + num_down=8, + base_channels=64, + norm_cfg=dict(type='BN'), + use_dropout=True, + init_cfg=dict(type='normal', gain=0.02)), + discriminator=dict( + type=PatchDiscriminator, + in_channels=6, + base_channels=64, + num_conv=3, + norm_cfg=dict(type='BN'), + init_cfg=dict(type='normal', gain=0.02)), + loss_config=dict(pixel_loss_weight=100.0), + default_domain=target_domain, + reachable_domains=[target_domain], + related_domains=[target_domain, source_domain]) diff --git a/mmagic/configs/_base_/models/base_styleganv1.py b/mmagic/configs/_base_/models/base_styleganv1.py new file mode 100644 index 000000000..8068d6873 --- /dev/null +++ b/mmagic/configs/_base_/models/base_styleganv1.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.editors.stylegan1 import (StyleGAN1, StyleGAN1Discriminator, + StyleGAN1Generator) + +model = dict( + type=StyleGAN1, + data_preprocessor=dict(type=DataPreprocessor), + style_channels=512, + generator=dict(type=StyleGAN1Generator, out_size=None, style_channels=512), + discriminator=dict(type=StyleGAN1Discriminator, in_size=None)) diff --git a/mmagic/configs/_base_/models/base_styleganv2.py b/mmagic/configs/_base_/models/base_styleganv2.py new file mode 100644 index 000000000..73bc8e32b --- /dev/null +++ b/mmagic/configs/_base_/models/base_styleganv2.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.model import ExponentialMovingAverage + +from mmagic.models import DataPreprocessor +from mmagic.models.editors import StyleGAN2 +from mmagic.models.editors.stylegan2 import (StyleGAN2Discriminator, + StyleGAN2Generator) + +# define GAN model + +d_reg_interval = 16 +g_reg_interval = 4 + +g_reg_ratio = g_reg_interval / (g_reg_interval + 1) +d_reg_ratio = d_reg_interval / (d_reg_interval + 1) + +loss_config = dict( + r1_loss_weight=10. / 2. * d_reg_interval, + r1_interval=d_reg_interval, + norm_mode='HWC', + g_reg_interval=g_reg_interval, + g_reg_weight=2. * g_reg_interval, + pl_batch_shrink=2) + +model = dict( + type=StyleGAN2, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict( + type=StyleGAN2Generator, + out_size=None, # Need to be set. + style_channels=512, + ), + discriminator=dict( + type=StyleGAN2Discriminator, + in_size=None, # Need to be set. + ), + ema_config=dict(type=ExponentialMovingAverage), + loss_config=loss_config) diff --git a/mmagic/configs/_base_/models/base_tof.py b/mmagic/configs/_base_/models/base_tof.py new file mode 100644 index 000000000..ba4f26abc --- /dev/null +++ b/mmagic/configs/_base_/models/base_tof.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dataset import DefaultSampler, InfiniteSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import MultiStepLR, OptimWrapper +from mmengine.runner import IterBasedTrainLoop + +from mmagic.datasets.transforms import LoadImageFromFile, PackInputs +from mmagic.engine.runner import MultiTestLoop, MultiValLoop +from mmagic.evaluation import MAE, PSNR, SSIM + +_base_ = '../default_runtime.py' + +train_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + channel_order='rgb', + imdecode_backend='pillow'), + dict( + type=LoadImageFromFile, + key='gt', + channel_order='rgb', + imdecode_backend='pillow'), + dict(type=PackInputs) +] + +demo_pipeline = [ + dict( + type=LoadImageFromFile, + key='img', + channel_order='rgb', + imdecode_backend='pillow'), + dict(type=PackInputs) +] + +# dataset settings +train_dataset_type = 'BasicFramesDataset' +val_dataset_type = 'BasicFramesDataset' +data_root = 'data/vimeo_triplet' +save_dir = './work_dirs' + +train_dataloader = dict( + num_workers=4, + persistent_workers=False, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=train_dataset_type, + ann_file='tri_trainlist.txt', + metainfo=dict(dataset_type='vimeo90k', task_name='vfi'), + data_root=data_root, + data_prefix=dict(img='sequences', gt='sequences'), + pipeline=train_pipeline, + depth=2, + load_frames_list=dict(img=['im1.png', 'im3.png'], gt=['im2.png']))) + +val_dataloader = dict( + num_workers=4, + persistent_workers=False, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=val_dataset_type, + ann_file='tri_testlist.txt', + metainfo=dict(dataset_type='vimeo90k', task_name='vfi'), + data_root=data_root, + data_prefix=dict(img='sequences', gt='sequences'), + pipeline=train_pipeline, + depth=2, + load_frames_list=dict(img=['im1.png', 'im3.png'], gt=['im2.png']))) + +test_dataloader = val_dataloader + +val_evaluator = [ + dict(type=MAE), + dict(type=PSNR), + dict(type=SSIM), +] +test_evaluator = val_evaluator + +# 5000 iters == 1 epoch +epoch_length = 5000 + +train_cfg = dict( + type=IterBasedTrainLoop, max_iters=1_000_000, val_interval=epoch_length) +val_cfg = dict(type=MultiValLoop) +test_cfg = dict(type=MultiTestLoop) + +# optimizer +optim_wrapper = dict( + constructor='DefaultOptimWrapperConstructor', + type=OptimWrapper, + optimizer=dict( + type='Adam', + lr=5e-5, + betas=(0.9, 0.99), + weight_decay=1e-4, + ), +) + +# learning policy +param_scheduler = dict( + type=MultiStepLR, + by_epoch=False, + gamma=0.5, + milestones=[200000, 400000, 600000, 800000]) + +default_hooks = dict( + checkpoint=dict( + type=CheckpointHook, + interval=epoch_length, + save_optimizer=True, + by_epoch=False, + out_dir=save_dir, + ), + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=100), + param_scheduler=dict(type=ParamSchedulerHook), + sampler_seed=dict(type=DistSamplerSeedHook), +) diff --git a/mmagic/configs/_base_/models/dcgan/base_dcgan_128x128.py b/mmagic/configs/_base_/models/dcgan/base_dcgan_128x128.py new file mode 100644 index 000000000..42cb98350 --- /dev/null +++ b/mmagic/configs/_base_/models/dcgan/base_dcgan_128x128.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.editors import DCGAN +from mmagic.models.editors.dcgan import DCGANDiscriminator, DCGANGenerator + +# define GAN model +model = dict( + type=DCGAN, + noise_size=100, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict(type=DCGANGenerator, output_scale=128, base_channels=1024), + discriminator=dict( + type=DCGANDiscriminator, + input_scale=128, + output_scale=4, + out_channels=100)) diff --git a/mmagic/configs/_base_/models/dcgan/base_dcgan_64x64.py b/mmagic/configs/_base_/models/dcgan/base_dcgan_64x64.py new file mode 100644 index 000000000..393c44e71 --- /dev/null +++ b/mmagic/configs/_base_/models/dcgan/base_dcgan_64x64.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.editors import DCGAN +from mmagic.models.editors.dcgan import DCGANDiscriminator, DCGANGenerator + +# define GAN model +model = dict( + type=DCGAN, + noise_size=100, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict(type=DCGANGenerator, output_scale=64, base_channels=1024), + discriminator=dict( + type=DCGANDiscriminator, + input_scale=64, + output_scale=4, + out_channels=1)) diff --git a/mmagic/configs/_base_/models/sagan/base_sagan_128x128.py b/mmagic/configs/_base_/models/sagan/base_sagan_128x128.py new file mode 100644 index 000000000..1e2674f51 --- /dev/null +++ b/mmagic/configs/_base_/models/sagan/base_sagan_128x128.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.editors import SAGAN +from mmagic.models.editors.biggan import SelfAttentionBlock +from mmagic.models.editors.sagan import ProjDiscriminator, SNGANGenerator + +model = dict( + type=SAGAN, + num_classes=1000, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict( + type=SNGANGenerator, + output_scale=128, + base_channels=64, + attention_cfg=dict(type=SelfAttentionBlock), + attention_after_nth_block=4, + with_spectral_norm=True), + discriminator=dict( + type=ProjDiscriminator, + input_scale=128, + base_channels=64, + attention_cfg=dict(type=SelfAttentionBlock), + attention_after_nth_block=1, + with_spectral_norm=True), + generator_steps=1, + discriminator_steps=1) diff --git a/mmagic/configs/_base_/models/sagan/base_sagan_32x32.py b/mmagic/configs/_base_/models/sagan/base_sagan_32x32.py new file mode 100644 index 000000000..83ac553ca --- /dev/null +++ b/mmagic/configs/_base_/models/sagan/base_sagan_32x32.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.editors import SAGAN +from mmagic.models.editors.biggan import SelfAttentionBlock +from mmagic.models.editors.sagan import ProjDiscriminator, SNGANGenerator + +model = dict( + type=SAGAN, + data_preprocessor=dict(type=DataPreprocessor), + num_classes=10, + generator=dict( + type=SNGANGenerator, + num_classes=10, + output_scale=32, + base_channels=256, + attention_cfg=dict(type=SelfAttentionBlock), + attention_after_nth_block=2, + with_spectral_norm=True), + discriminator=dict( + type=ProjDiscriminator, + num_classes=10, + input_scale=32, + base_channels=128, + attention_cfg=dict(type=SelfAttentionBlock), + attention_after_nth_block=1, + with_spectral_norm=True), + generator_steps=1, + discriminator_steps=5) diff --git a/mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_128x128.py b/mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_128x128.py new file mode 100644 index 000000000..02b76bf08 --- /dev/null +++ b/mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_128x128.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.editors.sagan import (SAGAN, ProjDiscriminator, + SNGANGenerator) + +# define GAN model +model = dict( + type=SAGAN, + num_classes=1000, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict(type=SNGANGenerator, output_scale=128, base_channels=64), + discriminator=dict( + type=ProjDiscriminator, input_scale=128, base_channels=64), + discriminator_steps=2) diff --git a/mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_32x32.py b/mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_32x32.py new file mode 100644 index 000000000..7678c2405 --- /dev/null +++ b/mmagic/configs/_base_/models/sngan_proj/base_sngan_proj_32x32.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmagic.models import DataPreprocessor +from mmagic.models.editors.sagan import (SAGAN, ProjDiscriminator, + SNGANGenerator) + +# define GAN model +model = dict( + type=SAGAN, + num_classes=10, + data_preprocessor=dict(type=DataPreprocessor), + generator=dict(type=SNGANGenerator, output_scale=32, base_channels=256), + discriminator=dict( + type=ProjDiscriminator, input_scale=32, base_channels=128), + discriminator_steps=5) diff --git a/mmagic/configs/_base_/schedules/.gitkeep b/mmagic/configs/_base_/schedules/.gitkeep new file mode 100644 index 000000000..e69de29bb