#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: aae_mnist.py
# Author: Qian Ge <geqian1001@gmail.com>
import sys
import argparse
import numpy as np
import tensorflow as tf
sys.path.append('../')
from src.dataflow.mnist import MNISTData
from src.models.aae import AAE
from src.helper.trainer import Trainer
from src.helper.generator import Generator
from src.helper.visualizer import Visualizer
DATA_PATH = '/home/qge2/workspace/data/MNIST_data/'
SAVE_PATH = '/home/qge2/workspace/data/out/vae/'
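
# Example invocations (illustrative only; the hard-coded DATA_PATH and
# SAVE_PATH above likely need to be edited for your environment):
#   python aae_mnist.py --train --ncode 2 --dist_type gaussian
#   python aae_mnist.py --train --label --dist_type gmm
#   python aae_mnist.py --train_supervised
#   python aae_mnist.py --train_semisupervised
#   python aae_mnist.py --generate --load 99
#   python aae_mnist.py --viz --load 99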


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true',
                        help='Train the model of Fig 1 and 3 in the paper.')
    parser.add_argument('--train_supervised', action='store_true',
                        help='Train the model of Fig 6 in the paper.')
    parser.add_argument('--train_semisupervised', action='store_true',
                        help='Train the model of Fig 8 in the paper.')
    parser.add_argument('--label', action='store_true',
                        help='Incorporate label info (Fig 3 in the paper).')
    parser.add_argument('--generate', action='store_true',
                        help='Sample images from trained model.')
    parser.add_argument('--viz', action='store_true',
                        help='Visualize learned model when ncode=2.')
    parser.add_argument('--supervise', action='store_true',
                        help='Sampling from supervised model (Fig 6 in the paper).')
    parser.add_argument('--load', type=int, default=99,
                        help='The epoch ID of pre-trained model to be restored.')
    parser.add_argument('--ncode', type=int, default=2,
                        help='Dimension of code')
    parser.add_argument('--dist_type', type=str, default='gaussian',
                        help='Prior distribution to be imposed on latent z (gaussian and gmm).')
    parser.add_argument('--noise', action='store_true',
                        help='Add noise to encoder input (Gaussian with std=0.6).')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Initial learning rate')
    parser.add_argument('--dropout', type=float, default=1.0,
                        help='Keep probability for dropout')
    parser.add_argument('--bsize', type=int, default=128,
                        help='Batch size')
    parser.add_argument('--maxepoch', type=int, default=100,
                        help='Max number of epochs')
    parser.add_argument('--encw', type=float, default=1.,
                        help='Weight of autoencoder loss')
    parser.add_argument('--genw', type=float, default=6.,
                        help='Weight of z generator loss')
    parser.add_argument('--disw', type=float, default=6.,
                        help='Weight of z discriminator loss')
    parser.add_argument('--clsw', type=float, default=1.,
                        help='Weight of semi-supervised loss')
    parser.add_argument('--ygenw', type=float, default=6.,
                        help='Weight of y generator loss')
    parser.add_argument('--ydisw', type=float, default=6.,
                        help='Weight of y discriminator loss')

    return parser.parse_args()


def preprocess_im(im):
    """ Normalize the input image to [-1., 1.] """
    im = im / 255. * 2. - 1.
    return im
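
# Quick sanity check of the normalization above (exact values):
#   preprocess_im(np.array([0., 127.5, 255.]))  ->  array([-1., 0., 1.])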


def read_train_data(batch_size, n_use_label=None, n_use_sample=None):
    """ Load the training data.

    If n_use_label or n_use_sample is not None, samples will be
    randomly picked to have a balanced number of examples.

    Args:
        batch_size (int): batch size
        n_use_label (int): how many labels are used for training
        n_use_sample (int): how many samples are used for training

    Returns:
        MNISTData
    """
    data = MNISTData('train',
                     data_dir=DATA_PATH,
                     shuffle=True,
                     pf=preprocess_im,
                     n_use_label=n_use_label,
                     n_use_sample=n_use_sample,
                     batch_dict_name=['im', 'label'])
    data.setup(epoch_val=0, batch_size=batch_size)
    return data
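
# Note: batch_dict_name=['im', 'label'] above suggests that each batch from the
# MNISTData dataflow is returned as a dict keyed by 'im' and 'label'. This is an
# assumption about src.dataflow.mnist, which is not shown in this file.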


def read_valid_data(batch_size):
    """ Load the validation data. """
    data = MNISTData('test',
                     data_dir=DATA_PATH,
                     shuffle=True,
                     pf=preprocess_im,
                     batch_dict_name=['im', 'label'])
    data.setup(epoch_val=0, batch_size=batch_size)
    return data


def semisupervised_train():
    """ Semi-supervised training (Fig 8 in the paper).

    Validation is run after each training epoch. The loss of each
    module is averaged and written to summaries every 100 steps.
    """
    FLAGS = get_args()

    # load dataset
    train_data_unlabel = read_train_data(FLAGS.bsize)
    train_data_label = read_train_data(FLAGS.bsize, n_use_sample=1280)
    train_data = {'unlabeled': train_data_unlabel, 'labeled': train_data_label}
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for semi-supervised training
    train_model = AAE(
        n_code=FLAGS.ncode, wd=0, n_class=10, add_noise=FLAGS.noise,
        enc_weight=FLAGS.encw, gen_weight=FLAGS.genw, dis_weight=FLAGS.disw,
        cat_dis_weight=FLAGS.ydisw, cat_gen_weight=FLAGS.ygenw, cls_weight=FLAGS.clsw)
    train_model.create_semisupervised_train_model()

    # create a separate AAE model for semi-supervised validation
    # that shares weights with the training model
    cls_valid_model = AAE(n_code=FLAGS.ncode, n_class=10)
    cls_valid_model.create_semisupervised_test_model()

    # initialize a trainer for training
    trainer = Trainer(train_model,
                      cls_valid_model=cls_valid_model,
                      generate_model=None,
                      train_data=train_data,
                      init_lr=FLAGS.lr,
                      save_path=SAVE_PATH)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_semisupervised_epoch(
                sess, ae_dropout=FLAGS.dropout, summary_writer=writer)
            trainer.valid_semisupervised_epoch(
                sess, valid_data, summary_writer=writer)
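
# Note on the semi-supervised setup above: the labeled dataflow uses only 1280
# training examples (n_use_sample=1280), while the unlabeled dataflow uses the
# full training set; per read_train_data, the labeled subset is sampled to be
# balanced.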


def supervised_train():
    """ Supervised training (Fig 6 in the paper).

    Validation is run after each training epoch. The loss of each
    module is averaged and written to summaries every 100 steps.
    Every 10 epochs, 10 different styles for each of the 10 digits
    are saved.
    """
    FLAGS = get_args()

    # load dataset
    train_data = read_train_data(FLAGS.bsize)
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for supervised training
    model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10,
                use_supervise=True, add_noise=FLAGS.noise,
                enc_weight=FLAGS.encw, gen_weight=FLAGS.genw,
                dis_weight=FLAGS.disw)
    model.create_train_model()

    # Create a separate AAE model for supervised validation that shares
    # weights with the training model. This model is used to generate
    # 10 different styles for each of the 10 digits every 10 epochs.
    valid_model = AAE(n_code=FLAGS.ncode, use_supervise=True, n_class=10)
    valid_model.create_generate_style_model(n_sample=10)

    # initialize a trainer for training
    trainer = Trainer(model, valid_model, train_data,
                      init_lr=FLAGS.lr, save_path=SAVE_PATH)
    # initialize a generator for generating style images
    generator = Generator(
        generate_model=valid_model, save_path=SAVE_PATH, n_labels=10)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_z_gan_epoch(
                sess, ae_dropout=FLAGS.dropout, summary_writer=writer)
            trainer.valid_epoch(sess, dataflow=valid_data, summary_writer=writer)
            if epoch_id % 10 == 0:
                saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
                generator.sample_style(sess, valid_data, plot_size=10,
                                       file_id=epoch_id, n_sample=10)
        saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))


def train():
    """ Unsupervised training, optionally incorporating label information
    in the adversarial regularization (Fig 1 and 3 in the paper).

    Validation is run after each training epoch. The loss of each
    module is averaged and written to summaries every 100 steps.
    Random samples and the learned latent space are saved every
    10 epochs.
    """
    FLAGS = get_args()
    # image size for visualization; plot_size * plot_size digits will be visualized
    plot_size = 20
    # use 10000 labels to train the latent space
    n_use_label = 10000

    # load dataset
    train_data = read_train_data(FLAGS.bsize, n_use_label=n_use_label)
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for training
    model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10,
                use_label=FLAGS.label, add_noise=FLAGS.noise,
                enc_weight=FLAGS.encw, gen_weight=FLAGS.genw,
                dis_weight=FLAGS.disw)
    model.create_train_model()

    # Create a separate AAE model for validation that shares weights
    # with the training model. This model is used to randomly sample
    # data every 10 epochs.
    valid_model = AAE(n_code=FLAGS.ncode, n_class=10)
    valid_model.create_generate_model(b_size=400)

    # initialize a trainer for training
    trainer = Trainer(model, valid_model, train_data,
                      distr_type=FLAGS.dist_type, use_label=FLAGS.label,
                      init_lr=FLAGS.lr, save_path=SAVE_PATH)

    # Initialize a visualizer and a generator to monitor the learned
    # latent space and data generation.
    # Latent space visualization is only available for code dim = 2.
    if FLAGS.ncode == 2:
        visualizer = Visualizer(model, save_path=SAVE_PATH)
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH,
                          distr_type=FLAGS.dist_type, n_labels=10,
                          use_label=FLAGS.label)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_z_gan_epoch(sess, ae_dropout=FLAGS.dropout, summary_writer=writer)
            trainer.valid_epoch(sess, dataflow=valid_data, summary_writer=writer)
            if epoch_id % 10 == 0:
                saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
                generator.generate_samples(sess, plot_size=plot_size, file_id=epoch_id)
                if FLAGS.ncode == 2:
                    visualizer.viz_2Dlatent_variable(sess, valid_data, file_id=epoch_id)
        saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
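
# Checkpoints are written by train() and supervised_train() above as
# '<SAVE_PATH>aae-epoch-<epoch_id>' every 10 epochs and once more at the end
# of training; generate() and visualize() below restore one of them through
# the --load flag (default 99, the last epoch when --maxepoch is 100).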


def generate():
    """ Sample images from a trained model. """
    FLAGS = get_args()
    plot_size = 20

    # create model for sampling
    generate_model = AAE(n_code=FLAGS.ncode, n_class=10)
    if FLAGS.supervise:
        # create the sampling model of Fig 6 in the paper
        generate_model.create_generate_style_model(n_sample=10)
    else:
        # create the sampling model of Fig 1 and 3 in the paper
        generate_model.create_generate_model(b_size=plot_size * plot_size)

    # initialize the Generator for sampling
    generator = Generator(generate_model=generate_model, save_path=SAVE_PATH,
                          distr_type=FLAGS.dist_type, n_labels=10, use_label=FLAGS.label)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, '{}aae-epoch-{}'.format(SAVE_PATH, FLAGS.load))
        if FLAGS.supervise:
            generator.sample_style(sess, plot_size=10, n_sample=10)
        else:
            generator.generate_samples(sess, plot_size=plot_size)


def visualize():
    """ Visualize the latent space of a trained model when ncode = 2. """
    FLAGS = get_args()
    if FLAGS.ncode != 2:
        raise ValueError('Visualization only for ncode = 2!')
    plot_size = 20

    # read validation set
    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    # create model for computing the latent z
    model = AAE(n_code=FLAGS.ncode, use_label=FLAGS.label, n_class=10)
    model.create_train_model()
    # create model for sampling images
    valid_model = AAE(n_code=FLAGS.ncode)
    valid_model.create_generate_model(b_size=400)

    # initialize Visualizer and Generator
    visualizer = Visualizer(model, save_path=SAVE_PATH)
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH,
                          distr_type=FLAGS.dist_type, n_labels=10,
                          use_label=FLAGS.label)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, '{}aae-epoch-{}'.format(SAVE_PATH, FLAGS.load))
        # visualize the learned latent space
        visualizer.viz_2Dlatent_variable(sess, valid_data)
        # visualize the learned manifold
        generator.generate_samples(sess, plot_size=plot_size, manifold=True)


if __name__ == '__main__':
    FLAGS = get_args()

    if FLAGS.train:
        train()
    elif FLAGS.train_supervised:
        supervised_train()
    elif FLAGS.train_semisupervised:
        semisupervised_train()
    elif FLAGS.generate:
        generate()
    elif FLAGS.viz:
        visualize()