-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into fix_fast_composer
- Loading branch information
Showing
48 changed files
with
3,197 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset import DefaultSampler | ||
|
||
from mmagic.datasets import BasicFramesDataset | ||
from mmagic.datasets.transforms import (GenerateSegmentIndices, | ||
LoadImageFromFile, MirrorSequence, | ||
PackInputs) | ||
from mmagic.engine.runner import MultiTestLoop | ||
from mmagic.evaluation import PSNR, SSIM | ||
|
||
# configs for REDS4 | ||
reds_data_root = 'data/REDS' | ||
|
||
reds_pipeline = [ | ||
dict(type=GenerateSegmentIndices, interval_list=[1]), | ||
dict(type=LoadImageFromFile, key='img', channel_order='rgb'), | ||
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), | ||
dict(type=PackInputs) | ||
] | ||
|
||
reds_dataloader = dict( | ||
num_workers=1, | ||
batch_size=1, | ||
persistent_workers=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=BasicFramesDataset, | ||
metainfo=dict(dataset_type='reds_reds4', task_name='vsr'), | ||
data_root=reds_data_root, | ||
data_prefix=dict(img='train_sharp_bicubic/X4', gt='train_sharp'), | ||
ann_file='meta_info_reds4_val.txt', | ||
depth=1, | ||
num_input_frames=100, | ||
fixed_seq_len=100, | ||
pipeline=reds_pipeline)) | ||
|
||
reds_evaluator = [ | ||
dict(type=PSNR, prefix='REDS4-BIx4-RGB'), | ||
dict(type=SSIM, prefix='REDS4-BIx4-RGB') | ||
] | ||
|
||
# configs for vimeo90k-bd and vimeo90k-bi | ||
vimeo_90k_data_root = 'data/vimeo90k' | ||
vimeo_90k_file_list = [ | ||
'im1.png', 'im2.png', 'im3.png', 'im4.png', 'im5.png', 'im6.png', 'im7.png' | ||
] | ||
|
||
vimeo_90k_pipeline = [ | ||
dict(type=LoadImageFromFile, key='img', channel_order='rgb'), | ||
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), | ||
dict(type=MirrorSequence, keys=['img']), | ||
dict(type=PackInputs) | ||
] | ||
|
||
vimeo_90k_bd_dataloader = dict( | ||
num_workers=1, | ||
batch_size=1, | ||
persistent_workers=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=BasicFramesDataset, | ||
metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'), | ||
data_root=vimeo_90k_data_root, | ||
data_prefix=dict(img='BDx4', gt='GT'), | ||
ann_file='meta_info_Vimeo90K_test_GT.txt', | ||
depth=2, | ||
num_input_frames=7, | ||
fixed_seq_len=7, | ||
load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']), | ||
pipeline=vimeo_90k_pipeline)) | ||
|
||
vimeo_90k_bi_dataloader = dict( | ||
num_workers=1, | ||
batch_size=1, | ||
persistent_workers=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=BasicFramesDataset, | ||
metainfo=dict(dataset_type='vimeo90k_seq', task_name='vsr'), | ||
data_root=vimeo_90k_data_root, | ||
data_prefix=dict(img='BIx4', gt='GT'), | ||
ann_file='meta_info_Vimeo90K_test_GT.txt', | ||
depth=2, | ||
num_input_frames=7, | ||
fixed_seq_len=7, | ||
load_frames_list=dict(img=vimeo_90k_file_list, gt=['im4.png']), | ||
pipeline=vimeo_90k_pipeline)) | ||
|
||
vimeo_90k_bd_evaluator = [ | ||
dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'), | ||
dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BDx4-Y'), | ||
] | ||
|
||
vimeo_90k_bi_evaluator = [ | ||
dict(type=PSNR, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'), | ||
dict(type=SSIM, convert_to='Y', prefix='Vimeo-90K-T-BIx4-Y'), | ||
] | ||
|
||
# config for UDM10 (BDx4) | ||
udm10_data_root = 'data/UDM10' | ||
|
||
udm10_pipeline = [ | ||
dict( | ||
type=GenerateSegmentIndices, | ||
interval_list=[1], | ||
filename_tmpl='{:04d}.png'), | ||
dict(type=LoadImageFromFile, key='img', channel_order='rgb'), | ||
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), | ||
dict(type=PackInputs) | ||
] | ||
|
||
udm10_dataloader = dict( | ||
num_workers=1, | ||
batch_size=1, | ||
persistent_workers=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=BasicFramesDataset, | ||
metainfo=dict(dataset_type='udm10', task_name='vsr'), | ||
data_root=udm10_data_root, | ||
data_prefix=dict(img='BDx4', gt='GT'), | ||
pipeline=udm10_pipeline)) | ||
|
||
udm10_evaluator = [ | ||
dict(type=PSNR, convert_to='Y', prefix='UDM10-BDx4-Y'), | ||
dict(type=SSIM, convert_to='Y', prefix='UDM10-BDx4-Y') | ||
] | ||
|
||
# config for vid4 | ||
vid4_data_root = 'data/Vid4' | ||
|
||
vid4_pipeline = [ | ||
dict(type=GenerateSegmentIndices, interval_list=[1]), | ||
dict(type=LoadImageFromFile, key='img', channel_order='rgb'), | ||
dict(type=LoadImageFromFile, key='gt', channel_order='rgb'), | ||
dict(type=PackInputs) | ||
] | ||
vid4_bd_dataloader = dict( | ||
num_workers=1, | ||
batch_size=1, | ||
persistent_workers=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=BasicFramesDataset, | ||
metainfo=dict(dataset_type='vid4', task_name='vsr'), | ||
data_root=vid4_data_root, | ||
data_prefix=dict(img='BDx4', gt='GT'), | ||
ann_file='meta_info_Vid4_GT.txt', | ||
depth=1, | ||
pipeline=vid4_pipeline)) | ||
|
||
vid4_bi_dataloader = dict( | ||
num_workers=1, | ||
batch_size=1, | ||
persistent_workers=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=BasicFramesDataset, | ||
metainfo=dict(dataset_type='vid4', task_name='vsr'), | ||
data_root=vid4_data_root, | ||
data_prefix=dict(img='BIx4', gt='GT'), | ||
ann_file='meta_info_Vid4_GT.txt', | ||
depth=1, | ||
pipeline=vid4_pipeline)) | ||
|
||
vid4_bd_evaluator = [ | ||
dict(type=PSNR, convert_to='Y', prefix='VID4-BDx4-Y'), | ||
dict(type=SSIM, convert_to='Y', prefix='VID4-BDx4-Y'), | ||
] | ||
vid4_bi_evaluator = [ | ||
dict(type=PSNR, convert_to='Y', prefix='VID4-BIx4-Y'), | ||
dict(type=SSIM, convert_to='Y', prefix='VID4-BIx4-Y'), | ||
] | ||
|
||
# config for test | ||
test_cfg = dict(type=MultiTestLoop) | ||
test_dataloader = [ | ||
reds_dataloader, | ||
vimeo_90k_bd_dataloader, | ||
vimeo_90k_bi_dataloader, | ||
udm10_dataloader, | ||
vid4_bd_dataloader, | ||
vid4_bi_dataloader, | ||
] | ||
test_evaluator = [ | ||
reds_evaluator, | ||
vimeo_90k_bd_evaluator, | ||
vimeo_90k_bi_evaluator, | ||
udm10_evaluator, | ||
vid4_bd_evaluator, | ||
vid4_bi_evaluator, | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset import DefaultSampler, InfiniteSampler | ||
|
||
from mmagic.evaluation import MAE, PSNR, SSIM | ||
|
||
# Base config for CelebA-HQ dataset | ||
|
||
# dataset settings | ||
dataset_type = 'BasicImageDataset' | ||
data_root = 'data/CelebA-HQ' | ||
|
||
train_dataloader = dict( | ||
num_workers=4, | ||
persistent_workers=False, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict(gt=''), | ||
ann_file='train_celeba_img_list.txt', | ||
test_mode=False, | ||
)) | ||
|
||
val_dataloader = dict( | ||
num_workers=4, | ||
persistent_workers=False, | ||
drop_last=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict(gt=''), | ||
ann_file='val_celeba_img_list.txt', | ||
test_mode=True, | ||
)) | ||
|
||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = [ | ||
dict(type=MAE, mask_key='mask', scaling=100), | ||
# By default, compute with pixel value from 0-1 | ||
# scale=2 to align with 1.0 | ||
# scale=100 seems to align with readme | ||
dict(type=PSNR), | ||
dict(type=SSIM), | ||
] | ||
|
||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset import DefaultSampler, InfiniteSampler | ||
|
||
from mmagic.datasets import CIFAR10 | ||
from mmagic.datasets.transforms import Flip, PackInputs | ||
|
||
cifar_pipeline = [ | ||
dict(type=Flip, keys=['gt'], flip_ratio=0.5, direction='horizontal'), | ||
dict(type=PackInputs) | ||
] | ||
cifar_dataset = dict( | ||
type=CIFAR10, | ||
data_root='./data', | ||
data_prefix='cifar10', | ||
test_mode=False, | ||
pipeline=cifar_pipeline) | ||
|
||
# test dataset do not use flip | ||
cifar_pipeline_test = [dict(type=PackInputs)] | ||
cifar_dataset_test = dict( | ||
type=CIFAR10, | ||
data_root='./data', | ||
data_prefix='cifar10', | ||
test_mode=False, | ||
pipeline=cifar_pipeline_test) | ||
|
||
train_dataloader = dict( | ||
num_workers=2, | ||
dataset=cifar_dataset, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
persistent_workers=True) | ||
|
||
val_dataloader = dict( | ||
batch_size=32, | ||
num_workers=2, | ||
dataset=cifar_dataset_test, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) | ||
|
||
test_dataloader = dict( | ||
batch_size=32, | ||
num_workers=2, | ||
dataset=cifar_dataset_test, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset import DefaultSampler, InfiniteSampler | ||
|
||
from mmagic.evaluation import SAD, ConnectivityError, GradientError, MattingMSE | ||
|
||
# Base config for Composition-1K dataset | ||
|
||
# dataset settings | ||
dataset_type = 'AdobeComp1kDataset' | ||
data_root = 'data/adobe_composition-1k' | ||
|
||
train_dataloader = dict( | ||
num_workers=4, | ||
persistent_workers=False, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='training_list.json', | ||
test_mode=False, | ||
)) | ||
|
||
val_dataloader = dict( | ||
num_workers=4, | ||
persistent_workers=False, | ||
drop_last=False, | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='test_list.json', | ||
test_mode=True, | ||
)) | ||
|
||
test_dataloader = val_dataloader | ||
|
||
# TODO: matting | ||
val_evaluator = [ | ||
dict(type=SAD), | ||
dict(type=MattingMSE), | ||
dict(type=GradientError), | ||
dict(type=ConnectivityError), | ||
] | ||
|
||
test_evaluator = val_evaluator |
Oops, something went wrong.