config.py
import argparse
import os
import random

import numpy as np
import torch

from utils.logger import Logger
from utils.visualizer import Visualizer
is_train, split = None, None
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# hardware
parser.add_argument('--n_workers', type=int, default=8, help='number of threads')
parser.add_argument('--gpus', type=str, default='1', help='visible GPU ids, separated by comma')
# data
parser.add_argument('--dset_dir', type=str, default='./')
parser.add_argument('--dset_name', type=str, default='moving_mnist')
# Input and Output length
parser.add_argument('--n_frames_input', type=int, default=10)
parser.add_argument('--n_frames_output', type=int, default=10)
# Components
parser.add_argument('--num_objects', type=int, nargs='+', default=[2], help='Max number of digits in Moving MNIST videos.')
parser.add_argument('--n_components', type=int, default=2)
# Dimensionality hyperparameters
parser.add_argument('--model', type=str, default='dive', help='Model name')
parser.add_argument('--image_latent_size', type=int, default=256,
                    help='Output size of the image encoder')
parser.add_argument('--appearance_latent_size', type=int, default=128,
                    help='Size of the mixed appearance vector; half the size of the static appearance vector.')
parser.add_argument('--pose_latent_size', type=int, default=3,
                    help='Size of the pose vector')
parser.add_argument('--hidden_size', type=int, default=64,
                    help='Size of the main hidden variables')
parser.add_argument('--var_app_size', type=int, default=48, help='Size of the varying appearance hidden representation')
parser.add_argument('--ngf', type=int, default=8,
                    help='Number of channels in encoder and decoder')
parser.add_argument('--stn_scale_prior', type=float, default=3,
                    help='Scale of the spatial transformer prior.')
# ckpt and logging
parser.add_argument('--ckpt_dir', type=str,
                    default=os.path.join('./tensorboard', 'ckpt'),
                    help='directory for checkpoints and logs')
parser.add_argument('--ckpt_name', type=str, default='', help='checkpoint name')
parser.add_argument('--log_every', type=int, default=50, help='log every x steps')
parser.add_argument('--save_every', type=int, default=50, help='save every x epochs')
parser.add_argument('--evaluate_every', type=int, default=50, help='evaluate on val set every x epochs')
# Variations to moving MNIST
parser.add_argument('--image_size', type=int, nargs='+', default=[64, 64])
parser.add_argument('--crop_size', type=int, nargs='+', default=[64, 64], help='Visible size on the bottom right side of the frame')
parser.add_argument('--use_crop_size', type=bool, default=False, help='Whether to restrict the visible region to crop_size (set automatically when crop_size is smaller than image_size)')
parser.add_argument('--num_missing', type=int, default=1, help='Number of timesteps with missing object per component')
parser.add_argument('--ini_et_alpha', type=int, default=100, help='Initial value of alpha for the elastic transformation')
# Flags
parser.add_argument('--with_imputation', type=bool, default=True, help='Whether we use imputation')
parser.add_argument('--with_var_appearance', type=bool, default=True, help='Whether we use varying appearance')
parser.add_argument('--soft_labels', type=bool, default=False, help='Whether we use Soft-labels (True) or Hard-labels (False)')
# Training helpers
parser.add_argument('--gamma_p_app', type=float, default=0.3,
                    help='Probability of generating static-only appearance')
parser.add_argument('--gamma_p_imp', type=float, default=0.25,
                    help='Probability of imputing independently from the missing labels')
parser.add_argument('--gamma_switch_step', type=int, default=3e3,
                    help='How many iterations before reducing gamma of appearance.')
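
# Example invocation (illustrative only): assuming a train.py entry point that
# calls config.build(is_train=True), the flags above could be combined like
#
#   python train.py --gpus 0 --dset_name moving_mnist \
#       --n_frames_input 10 --n_frames_output 10 --num_missing 1
#
# train.py and the flag values shown here are assumptions, not taken from this file.
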
def parse(is_train):
    if is_train:
        parser.add_argument('--batch_size', type=int, default=64, help='batch size per gpu')
        parser.add_argument('--n_epochs', type=int, default=1000, help='total # of epochs')
        parser.add_argument('--n_iters', type=int, default=0, help='total # of iterations')
        parser.add_argument('--start_epoch', type=int, default=0, help='starting epoch')
        parser.add_argument('--lr_init', type=float, default=1e-3, help='initial learning rate')
        parser.add_argument('--lr_decay', type=int, default=1, choices=[0, 1], help='whether to decay learning rate')
        parser.add_argument('--load_ckpt_dir', type=str, default='', help='load checkpoint dir placeholder')
        parser.add_argument('--load_ckpt_epoch', type=int, default=0, help='epoch to load checkpoint')
    else:
        # Evaluation-only hyperparameters
        parser.add_argument('--batch_size', type=int, default=1, help='batch size')
        parser.add_argument('--which_epochs', type=int, nargs='+', default=[-1],
                            help='which epochs to evaluate, -1 to load latest checkpoint')
        parser.add_argument('--save_visuals', type=int, default=0, help='Save results to tensorboard')
        parser.add_argument('--save_all_results', type=int, default=0, help='Save results to tensorboard')

    opt = parser.parse_args()

    # Dataset-specific overrides
    if opt.dset_name == 'moving_mnist':
        opt.n_channels = 1
        opt.stn_scale_prior = 3
        opt.num_objects = [2]
        opt.n_components = 2
        if opt.crop_size[0] < opt.image_size[0] or opt.crop_size[1] < opt.image_size[1]:
            opt.use_crop_size = True
    elif opt.dset_name == 'pedestrian':
        opt.n_channels = 1
        opt.image_size = [256, 256]
        opt.crop_size = [256, 256]
        opt.stn_scale_prior = 3.5
        opt.num_objects = [3]
        opt.n_components = 3
        opt.n_frames_output = 5
        opt.gamma_switch_step = 5e3
        opt.batch_size = 32
        opt.num_missing = 1
    else:
        raise NotImplementedError

    assert opt.n_frames_input > 0 and opt.n_frames_output > 0
    opt.dset_path = os.path.join(opt.dset_dir, opt.dset_name)

    if is_train:
        opt.is_train = True
        opt.split = 'train'
    else:
        opt.is_train = False
        opt.split = 'val'

    # NOTE: the checkpoint name is hard-coded here and overrides --ckpt_name.
    ckpt_name = 'dive_2'
    opt.ckpt_name = ckpt_name
    opt.ckpt_path = os.path.join(opt.ckpt_dir, opt.dset_name, ckpt_name)

    # Logging
    log = ['Arguments: ']
    for k, v in sorted(vars(opt).items()):
        log.append('{}: {}'.format(k, v))
    return opt, log

def build(is_train, tb_dir=None):
    opt, log = parse(is_train)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus
    os.makedirs(opt.ckpt_path, exist_ok=True)

    # Set seed
    torch.manual_seed(666)
    torch.cuda.manual_seed_all(666)
    np.random.seed(666)
    random.seed(666)

    logger = Logger(opt.ckpt_path, opt.split)
    if tb_dir is not None:
        tb_path = os.path.join(opt.ckpt_path, tb_dir)
        vis = Visualizer(tb_path)
    else:
        vis = None

    logger.print(log)
    return opt, logger, vis
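
# Usage sketch (illustrative, not part of the original file): a training script
# would typically import this module and call build() once at startup. The guard
# below only runs when config.py is executed directly; 'tb' is an arbitrary
# TensorBoard sub-directory name chosen for this example.
if __name__ == '__main__':
    opt, logger, vis = build(is_train=True, tb_dir='tb')
    logger.print(['Config ready for dataset: {}'.format(opt.dset_name)])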