forked from yunjey/stargan
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
118 lines (112 loc) · 5.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import paddle
import os
import argparse
from solver import Solver
from data_loader import get_loader
def str2bool(v):
    """Parse a command-line string into a boolean (for argparse ``type=``).

    Case-insensitively, and ignoring surrounding whitespace, the strings
    'true', 't', 'yes', 'y' and '1' map to True; any other string maps to
    False.

    An exact membership test is used on purpose: a substring test such as
    ``v.lower() in 'true'`` would wrongly accept '', 'r', 'ue', 'rue', etc.
    """
    return v.strip().lower() in ('true', 't', 'yes', 'y', '1')
def main(config):
    """Prepare output directories, build data loaders and run the solver.

    Args:
        config: parsed argument namespace. It supplies the output directory
            paths (``log_dir``, ``model_save_dir``, ``sample_dir``,
            ``result_dir``), the dataset selection (``dataset``: 'CelebA',
            'RaFD' or 'Both'), the loader settings passed to ``get_loader``,
            and ``mode`` ('train' or 'test').

    A loader is built only for the dataset(s) selected; the other loader is
    passed to ``Solver`` as None. Single-dataset runs call
    ``train``/``test``; 'Both' calls ``train_multi``/``test_multi``.
    """
    # exist_ok=True replaces the exists()/makedirs() pairs, which could race
    # if another process created the directory between the check and mkdir.
    for directory in (config.log_dir, config.model_save_dir,
                      config.sample_dir, config.result_dir):
        os.makedirs(directory, exist_ok=True)

    celeba_loader = None
    rafd_loader = None
    if config.dataset in ('CelebA', 'Both'):
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path,
                                   config.selected_attrs,
                                   config.celeba_crop_size, config.image_size,
                                   config.batch_size, 'CelebA', config.mode,
                                   config.num_workers)
    if config.dataset in ('RaFD', 'Both'):
        # RaFD has no attribute file / attribute selection, hence the Nones.
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size,
                                 config.batch_size, 'RaFD', config.mode,
                                 config.num_workers)

    solver = Solver(celeba_loader, rafd_loader, config)

    if config.mode == 'train':
        if config.dataset in ('CelebA', 'RaFD'):
            solver.train()
        elif config.dataset == 'Both':
            solver.train_multi()
    elif config.mode == 'test':
        if config.dataset in ('CelebA', 'RaFD'):
            solver.test()
        elif config.dataset == 'Both':
            solver.test_multi()
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()

    # --- Model architecture and loss weights -------------------------------
    arg_parser.add_argument('--c_dim', type=int, default=5,
                            help='dimension of domain labels (1st dataset)')
    arg_parser.add_argument('--c2_dim', type=int, default=8,
                            help='dimension of domain labels (2nd dataset)')
    arg_parser.add_argument('--celeba_crop_size', type=int, default=178,
                            help='crop size for the CelebA dataset')
    arg_parser.add_argument('--rafd_crop_size', type=int, default=256,
                            help='crop size for the RaFD dataset')
    arg_parser.add_argument('--image_size', type=int, default=128,
                            help='image resolution')
    arg_parser.add_argument('--g_conv_dim', type=int, default=64,
                            help='number of conv filters in the first layer of G')
    arg_parser.add_argument('--d_conv_dim', type=int, default=64,
                            help='number of conv filters in the first layer of D')
    arg_parser.add_argument('--g_repeat_num', type=int, default=6,
                            help='number of residual blocks in G')
    arg_parser.add_argument('--d_repeat_num', type=int, default=6,
                            help='number of strided conv layers in D')
    # Defaults below are ints on purpose: argparse does not apply `type`
    # to non-string defaults, so these stay ints unless given on the CLI.
    arg_parser.add_argument('--lambda_cls', type=float, default=1,
                            help='weight for domain classification loss')
    arg_parser.add_argument('--lambda_rec', type=float, default=10,
                            help='weight for reconstruction loss')
    arg_parser.add_argument('--lambda_gp', type=float, default=10,
                            help='weight for gradient penalty')

    # --- Training schedule and optimizer -----------------------------------
    arg_parser.add_argument('--dataset', type=str, default='CelebA',
                            choices=['CelebA', 'RaFD', 'Both'])
    arg_parser.add_argument('--batch_size', type=int, default=16,
                            help='mini-batch size')
    arg_parser.add_argument('--num_iters', type=int, default=200000,
                            help='number of total iterations for training D')
    arg_parser.add_argument('--num_iters_decay', type=int, default=100000,
                            help='number of iterations for decaying lr')
    arg_parser.add_argument('--g_lr', type=float, default=0.0001,
                            help='learning rate for G')
    arg_parser.add_argument('--d_lr', type=float, default=0.0001,
                            help='learning rate for D')
    arg_parser.add_argument('--n_critic', type=int, default=5,
                            help='number of D updates per each G update')
    arg_parser.add_argument('--beta1', type=float, default=0.5,
                            help='beta1 for Adam optimizer')
    arg_parser.add_argument('--beta2', type=float, default=0.999,
                            help='beta2 for Adam optimizer')
    arg_parser.add_argument('--resume_iters', type=int, default=None,
                            help='resume training from this step')
    arg_parser.add_argument('--selected_attrs', '--list', nargs='+',
                            help='selected attributes for the CelebA dataset',
                            default=['Black_Hair', 'Blond_Hair', 'Brown_Hair',
                                     'Male', 'Young'])

    # --- Testing ------------------------------------------------------------
    arg_parser.add_argument('--test_iters', type=int, default=200000,
                            help='test model from this step')

    # --- Runtime ------------------------------------------------------------
    arg_parser.add_argument('--num_workers', type=int, default=1)
    arg_parser.add_argument('--mode', type=str, default='train',
                            choices=['train', 'test'])
    arg_parser.add_argument('--use_tensorboard', type=str2bool, default=True)

    # --- Input / output locations ------------------------------------------
    arg_parser.add_argument('--celeba_image_dir', type=str,
                            default='./data/celeba/images')
    arg_parser.add_argument('--attr_path', type=str,
                            default='./data/celeba/list_attr_celeba.txt')
    arg_parser.add_argument('--rafd_image_dir', type=str,
                            default='data/RaFD/train')
    arg_parser.add_argument('--log_dir', type=str, default='stargan/logs')
    arg_parser.add_argument('--model_save_dir', type=str,
                            default='stargan/models')
    arg_parser.add_argument('--sample_dir', type=str,
                            default='stargan/samples')
    arg_parser.add_argument('--result_dir', type=str,
                            default='stargan/results')

    # --- Logging / checkpoint cadence (in iterations) ----------------------
    arg_parser.add_argument('--log_step', type=int, default=10)
    arg_parser.add_argument('--sample_step', type=int, default=1000)
    arg_parser.add_argument('--model_save_step', type=int, default=10000)
    arg_parser.add_argument('--lr_update_step', type=int, default=1000)

    cli_config = arg_parser.parse_args()

    # NOTE(review): these two assignments unconditionally override the parsed
    # values, so --use_tensorboard and --num_workers currently have no effect.
    # Presumably deliberate for this port (tensorboard logging / loader worker
    # processes not supported?) — confirm, and consider removing the flags or
    # changing their defaults instead.
    cli_config.use_tensorboard = False
    cli_config.num_workers = 0

    print(cli_config)
    main(cli_config)