Skip to content

Commit

Permalink
Merge branch 'main' into fix_fast_composer
Browse files Browse the repository at this point in the history
  • Loading branch information
zengyh1900 committed Nov 10, 2023
2 parents 0383703 + c2f8f3a commit 47c727f
Show file tree
Hide file tree
Showing 48 changed files with 3,197 additions and 4 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ Currently, MMagic support multiple image and video generation/editing tasks.

https://user-images.githubusercontent.com/49083766/233564593-7d3d48ed-e843-4432-b610-35e3d257765c.mp4

The best practice on our main branch works with **Python 3.8+** and **PyTorch 1.10+**.

### ✨ Major features

- **State of the Art Models**
Expand All @@ -156,6 +154,10 @@ The best practice on our main branch works with **Python 3.8+** and **PyTorch 1.

By using MMEngine and MMCV of the OpenMMLab 2.0 framework, MMagic decomposes the editing framework into different modules, and one can easily construct a customized editor framework by combining different modules. We can define the training process just like playing with Legos, and MMagic provides rich components and strategies. In MMagic, you have complete control over the training process through different levels of APIs. With the support of [MMSeparateDistributedDataParallel](https://github.com/open-mmlab/mmengine/blob/main/mmengine/model/wrappers/seperate_distributed.py), distributed training for dynamic architectures can be easily implemented.

### ✨ Best Practice

- The best practice on our main branch works with **Python 3.9+** and **PyTorch 2.0+**.

<p align="right"><a href="#table">🔝Back to Table of Contents</a></p>

## 🙌 Contributing
Expand Down
6 changes: 4 additions & 2 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ MMagic 是基于 PyTorch 的图像&视频编辑和生成开源工具箱。是 [O

https://user-images.githubusercontent.com/49083766/233564593-7d3d48ed-e843-4432-b610-35e3d257765c.mp4

主分支代码的最佳实践基于 **Python 3.8+** 和 **PyTorch 1.10+**

### ✨ 主要特性

- **SOTA 算法**
Expand All @@ -154,6 +152,10 @@ https://user-images.githubusercontent.com/49083766/233564593-7d3d48ed-e843-4432-

通过 OpenMMLab 2.0 框架的 MMEngine 和 MMCV, MMagic 将编辑框架分解为不同的组件,并且可以通过组合不同的模块轻松地构建自定义的编辑器模型。我们可以像搭建“乐高”一样定义训练流程,提供丰富的组件和策略。在 MMagic 中,你可以使用不同的 APIs 完全控制训练流程。得益于 [MMSeparateDistributedDataParallel](https://github.com/open-mmlab/mmengine/blob/main/mmengine/model/wrappers/seperate_distributed.py), 动态模型结构的分布式训练可以轻松实现。

### ✨ 最佳实践

- 主分支代码的最佳实践基于 **Python 3.9+** 和 **PyTorch 2.0+**

<p align="right"><a href="#table">🔝返回目录</a></p>

## 🙌 参与贡献
Expand Down
192 changes: 192 additions & 0 deletions mmagic/configs/_base_/datasets/basicvsr_test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler

from mmagic.datasets import BasicFramesDataset
from mmagic.datasets.transforms import (GenerateSegmentIndices,
LoadImageFromFile, MirrorSequence,
PackInputs)
from mmagic.engine.runner import MultiTestLoop
from mmagic.evaluation import PSNR, SSIM

# configs for REDS4
# Test-time pipeline, dataloader and metrics for the REDS4 split.
reds_data_root = 'data/REDS'

# Per-sample transforms, in listed order: GenerateSegmentIndices with
# interval 1, LoadImageFromFile for the 'img' and 'gt' keys (both with
# channel_order='rgb'), then PackInputs.
reds_pipeline = [
dict(type=GenerateSegmentIndices, interval_list=[1]),
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=PackInputs)
]

# Single worker, batch size 1, DefaultSampler without shuffling.
reds_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='reds_reds4', task_name='vsr'),
data_root=reds_data_root,
# 'img' is read from 'train_sharp_bicubic/X4', 'gt' from 'train_sharp'.
data_prefix=dict(img='train_sharp_bicubic/X4', gt='train_sharp'),
ann_file='meta_info_reds4_val.txt',
# NOTE(review): depth / num_input_frames / fixed_seq_len are interpreted by
# BasicFramesDataset; presumably each REDS4 clip has 100 frames - confirm.
depth=1,
num_input_frames=100,
fixed_seq_len=100,
pipeline=reds_pipeline))

# PSNR and SSIM, both configured with the prefix 'REDS4-BIx4-RGB'.
# No convert_to argument is passed for this dataset.
reds_evaluator = [
dict(type=PSNR, prefix='REDS4-BIx4-RGB'),
dict(type=SSIM, prefix='REDS4-BIx4-RGB')
]

# configs for vimeo90k-bd and vimeo90k-bi
# Data root shared by the BD and BI Vimeo-90K dataloaders.
vimeo_90k_data_root = 'data/vimeo90k'
# Frame file names 'im1.png' through 'im7.png', in ascending order.
vimeo_90k_file_list = [f'im{idx}.png' for idx in range(1, 8)]

# Per-sample transforms shared by both Vimeo-90K dataloaders, in listed
# order: load 'img' and 'gt' (channel_order='rgb'), apply MirrorSequence to
# the 'img' key only, then PackInputs.
vimeo_90k_pipeline = [
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=MirrorSequence, keys=['img']),
dict(type=PackInputs)
]

# Dataloader for the 'BDx4' inputs. Identical to vimeo_90k_bi_dataloader
# except for the 'img' entry of data_prefix.
vimeo_90k_bd_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'),
data_root=vimeo_90k_data_root,
data_prefix=dict(img='BDx4', gt='GT'),
ann_file='meta_info_Vimeo90K_test_GT.txt',
depth=2,
num_input_frames=7,
fixed_seq_len=7,
# 'img' uses every name in vimeo_90k_file_list; 'gt' uses only 'im4.png'.
load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']),
pipeline=vimeo_90k_pipeline))

# Dataloader for the 'BIx4' inputs. Identical to vimeo_90k_bd_dataloader
# except for the 'img' entry of data_prefix.
vimeo_90k_bi_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'),
data_root=vimeo_90k_data_root,
data_prefix=dict(img='BIx4', gt='GT'),
ann_file='meta_info_Vimeo90K_test_GT.txt',
depth=2,
num_input_frames=7,
fixed_seq_len=7,
# 'img' uses every name in vimeo_90k_file_list; 'gt' uses only 'im4.png'.
load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']),
pipeline=vimeo_90k_pipeline))

# PSNR / SSIM for the BDx4 loader, both with convert_to='Y'.
vimeo_90k_bd_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'),
]

# PSNR / SSIM for the BIx4 loader, both with convert_to='Y'.
vimeo_90k_bi_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'),
]

# config for UDM10 (BDx4)
# Test-time pipeline, dataloader and metrics for UDM10.
udm10_data_root = 'data/UDM10'

# Same transform sequence as the other frame-sequence pipelines in this
# file, but GenerateSegmentIndices is given an explicit
# filename_tmpl='{:04d}.png' (four-digit, zero-padded frame names).
udm10_pipeline = [
dict(
type=GenerateSegmentIndices,
interval_list=[1],
filename_tmpl='{:04d}.png'),
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=PackInputs)
]

# NOTE(review): unlike the other dataloaders in this file, no ann_file,
# depth, num_input_frames or fixed_seq_len is passed here, so the
# BasicFramesDataset defaults apply - confirm those defaults suit UDM10.
udm10_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='udm10', task_name='vsr'),
data_root=udm10_data_root,
data_prefix=dict(img='BDx4', gt='GT'),
pipeline=udm10_pipeline))

# PSNR / SSIM, both with convert_to='Y' and prefix 'UDM10-BDx4-Y'.
udm10_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='UDM10-BDx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='UDM10-BDx4-Y')
]

# config for vid4
# Test-time pipeline, dataloaders (BDx4 and BIx4 inputs) and metrics for Vid4.
vid4_data_root = 'data/Vid4'

# Transforms in listed order: GenerateSegmentIndices with interval 1, load
# 'img' and 'gt' (channel_order='rgb'), then PackInputs.
vid4_pipeline = [
dict(type=GenerateSegmentIndices, interval_list=[1]),
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'),
dict(type=PackInputs)
]
# Dataloader for the 'BDx4' inputs. Identical to vid4_bi_dataloader except
# for the 'img' entry of data_prefix.
vid4_bd_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vid4', task_name='vsr'),
data_root=vid4_data_root,
data_prefix=dict(img='BDx4', gt='GT'),
ann_file='meta_info_Vid4_GT.txt',
depth=1,
pipeline=vid4_pipeline))

# Dataloader for the 'BIx4' inputs. Identical to vid4_bd_dataloader except
# for the 'img' entry of data_prefix.
vid4_bi_dataloader = dict(
num_workers=1,
batch_size=1,
persistent_workers=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=BasicFramesDataset,
metainfo=dict(dataset_type='vid4', task_name='vsr'),
data_root=vid4_data_root,
data_prefix=dict(img='BIx4', gt='GT'),
ann_file='meta_info_Vid4_GT.txt',
depth=1,
pipeline=vid4_pipeline))

# PSNR / SSIM for the BDx4 loader, both with convert_to='Y'.
vid4_bd_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='VID4-BDx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='VID4-BDx4-Y'),
]
# PSNR / SSIM for the BIx4 loader, both with convert_to='Y'.
vid4_bi_evaluator = [
dict(type=PSNR, convert_to='Y', prefix='VID4-BIx4-Y'),
dict(type=SSIM, convert_to='Y', prefix='VID4-BIx4-Y'),
]

# config for test
# The test loop is MultiTestLoop, and test_dataloader / test_evaluator are
# lists rather than single configs.
test_cfg = dict(type=MultiTestLoop)
# Six dataloaders: REDS4, Vimeo-90K BD/BI, UDM10, Vid4 BD/BI.
test_dataloader = [
reds_dataloader,
vimeo_90k_bd_dataloader,
vimeo_90k_bi_dataloader,
udm10_dataloader,
vid4_bd_dataloader,
vid4_bi_dataloader,
]
# Evaluator groups listed in the same dataset order as test_dataloader.
# NOTE(review): presumably MultiTestLoop pairs the two lists by position,
# so keep both lists in the same order - confirm against MultiTestLoop.
test_evaluator = [
reds_evaluator,
vimeo_90k_bd_evaluator,
vimeo_90k_bi_evaluator,
udm10_evaluator,
vid4_bd_evaluator,
vid4_bi_evaluator,
]
48 changes: 48 additions & 0 deletions mmagic/configs/_base_/datasets/celeba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler, InfiniteSampler

from mmagic.evaluation import MAE, PSNR, SSIM

# Base config for CelebA-HQ dataset

# dataset settings
# NOTE(review): the dataset type is given as a string name, whereas the
# samplers and metrics in this file are imported classes - confirm this
# mix is intentional.
dataset_type = 'BasicImageDataset'
data_root = 'data/CelebA-HQ'

# Training loader: InfiniteSampler with shuffling. No batch_size or pipeline
# is set here; presumably configs that build on this base supply them.
train_dataloader = dict(
num_workers=4,
persistent_workers=False,
sampler=dict(type=InfiniteSampler, shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
# Empty 'gt' prefix: no sub-directory is added for ground-truth images.
data_prefix=dict(gt=''),
ann_file='train_celeba_img_list.txt',
test_mode=False,
))

# Validation loader: DefaultSampler without shuffling, drop_last=False.
val_dataloader = dict(
num_workers=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(gt=''),
ann_file='val_celeba_img_list.txt',
test_mode=True,
))

# Alias, not a copy: test_dataloader is the same dict object as
# val_dataloader.
test_dataloader = val_dataloader

val_evaluator = [
# Original author's note on the MAE `scaling` argument:
# by default the metric is computed with pixel values in 0-1;
# scaling=2 aligns with 1.0 (presumably the 1.0 release - confirm);
# scaling=100 seems to align with the numbers in the README.
dict(type=MAE, mask_key='mask', scaling=100),
dict(type=PSNR),
dict(type=SSIM),
]

# Alias, not a copy: the same evaluator list is used for testing.
test_evaluator = val_evaluator
45 changes: 45 additions & 0 deletions mmagic/configs/_base_/datasets/cifar10_nopad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler, InfiniteSampler

from mmagic.datasets import CIFAR10
from mmagic.datasets.transforms import Flip, PackInputs

# Training pipeline: Flip on the 'gt' key (horizontal, flip_ratio=0.5),
# then PackInputs.
cifar_pipeline = [
dict(type=Flip, keys=['gt'], flip_ratio=0.5, direction='horizontal'),
dict(type=PackInputs)
]
cifar_dataset = dict(
type=CIFAR10,
data_root='./data',
data_prefix='cifar10',
test_mode=False,
pipeline=cifar_pipeline)

# The test dataset does not use Flip; its pipeline is PackInputs only.
cifar_pipeline_test = [dict(type=PackInputs)]
# NOTE(review): test_mode=False is also used for the val/test dataset, so it
# differs from cifar_dataset only in its pipeline - confirm this is
# intentional.
cifar_dataset_test = dict(
type=CIFAR10,
data_root='./data',
data_prefix='cifar10',
test_mode=False,
pipeline=cifar_pipeline_test)

# Training loader: InfiniteSampler with shuffling. No batch_size is set
# here; presumably configs that build on this base supply it.
train_dataloader = dict(
num_workers=2,
dataset=cifar_dataset,
sampler=dict(type=InfiniteSampler, shuffle=True),
persistent_workers=True)

# Validation loader: batch size 32, DefaultSampler without shuffling.
val_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=cifar_dataset_test,
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

# Test loader: same settings as val_dataloader, written as a separate dict.
test_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=cifar_dataset_test,
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)
45 changes: 45 additions & 0 deletions mmagic/configs/_base_/datasets/comp1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import DefaultSampler, InfiniteSampler

from mmagic.evaluation import SAD, ConnectivityError, GradientError, MattingMSE

# Base config for Composition-1K dataset

# dataset settings
# NOTE(review): the dataset type is given as a string name, whereas the
# samplers and metrics in this file are imported classes - confirm this
# mix is intentional.
dataset_type = 'AdobeComp1kDataset'
data_root = 'data/adobe_composition-1k'

# Training loader: InfiniteSampler with shuffling, annotations from
# 'training_list.json'. No batch_size or pipeline is set here; presumably
# configs that build on this base supply them.
train_dataloader = dict(
num_workers=4,
persistent_workers=False,
sampler=dict(type=InfiniteSampler, shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='training_list.json',
test_mode=False,
))

# Validation loader: DefaultSampler without shuffling, drop_last=False,
# annotations from 'test_list.json'.
val_dataloader = dict(
num_workers=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type=DefaultSampler, shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='test_list.json',
test_mode=True,
))

# Alias, not a copy: test_dataloader is the same dict object as
# val_dataloader.
test_dataloader = val_dataloader

# TODO: matting (original note; it is unclear what remains to be done here -
# clarify or remove)
# Metrics: SAD, MattingMSE, GradientError and ConnectivityError, each with
# its default arguments.
val_evaluator = [
dict(type=SAD),
dict(type=MattingMSE),
dict(type=GradientError),
dict(type=ConnectivityError),
]

# Alias, not a copy: the same evaluator list is used for testing.
test_evaluator = val_evaluator
Loading

0 comments on commit 47c727f

Please sign in to comment.