-
Notifications
You must be signed in to change notification settings - Fork 280
/
__init__.py
427 lines (363 loc) · 14 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
import os
import time
from argparse import Namespace
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torchbenchmark import DATA_PATH
from torchbenchmark.tasks import COMPUTER_VISION
from ...util.model import BenchmarkModel
from .data_loader import VideoData
from .functions import compose_image_withshift, write_tb_log
from .loss_functions import alpha_gradient_loss, alpha_loss, compose_loss, GANloss
from .networks import conv_init, MultiscaleDiscriminator, ResnetConditionHR
# Let cuDNN autotune convolution algorithms for the fixed input resolution,
# accepting non-deterministic kernels in exchange for speed.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
def _collate_filter_none(batch):
    """Collate ``batch`` with PyTorch's default collate after dropping ``None`` samples.

    Presumably the dataset returns ``None`` for samples it cannot use
    (TODO confirm against ``VideoData``); such entries are skipped so the
    remaining samples can still be stacked into one batch.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)
def _create_data_dir():
    """Return the ``.data`` directory next to this module, creating it (and parents) if missing."""
    target = Path(__file__).parent / ".data"
    target.mkdir(parents=True, exist_ok=True)
    return target
class Model(BenchmarkModel):
    """TorchBench wrapper for Background Matting generator/discriminator training.

    Three networks are built in ``__init__``:

    * ``netB`` -- a frozen ``ResnetConditionHR`` whose outputs serve as
      pseudo-supervision targets in ``train``;
    * ``netG`` -- the trainable ``ResnetConditionHR`` generator, which is the
      module exposed by ``get_module``;
    * ``netD`` -- a ``MultiscaleDiscriminator`` trained adversarially.

    ``eval`` is not implemented.
    """

    task = COMPUTER_VISION.PATTERN_RECOGNITION
    # Original batch size: 4
    # Original hardware: unknown
    # Source: https://arxiv.org/pdf/2004.00626.pdf
    DEFAULT_TRAIN_BSIZE = 4
    DEFAULT_EVAL_BSIZE = 1
    ALLOW_CUSTOMIZE_BSIZE = False

    # Batch-dict keys of the tensors fed to netB/netG, in positional-argument order.
    _NET_INPUT_KEYS = ("image", "bg", "seg", "multi_fr")

    def __init__(self, test, device, batch_size=None, extra_args=None):
        """Prepare input data, cache all training batches, and build nets/optimizers.

        Args:
            test: forwarded unchanged to ``BenchmarkModel``.
            device: target device; models and cached batches are moved to it
                with ``.to(device)``.
            batch_size: optional batch size, forwarded to ``BenchmarkModel``.
            extra_args: optional list of extra arguments forwarded to
                ``BenchmarkModel``; ``None`` means an empty list.

        Raises:
            RuntimeError: if the input data is not present under ``DATA_PATH``
                and installing it fails.
        """
        # Avoid a shared mutable default argument.
        if extra_args is None:
            extra_args = []
        super().__init__(
            test=test, device=device, batch_size=batch_size, extra_args=extra_args
        )
        self.opt = Namespace(
            n_blocks1=7,
            n_blocks2=3,
            batch_size=self.batch_size,
            resolution=512,
            name="Real_fixed",
        )

        datadir = self._locate_input_data()
        self.train_data = self._load_train_data(datadir)

        # netB: frozen; its predictions are the pseudo-supervision targets in train().
        # NOTE(review): no weights are loaded into netB in this class, so it runs
        # with whatever initialization its constructor performs -- confirm intended.
        netB = ResnetConditionHR(
            input_nc=(3, 3, 1, 4),
            output_nc=4,
            n_blocks1=self.opt.n_blocks1,
            n_blocks2=self.opt.n_blocks2,
        )
        netB.to(self.device)
        netB.eval()
        for param in netB.parameters():  # freeze netB
            param.requires_grad = False
        self.netB = netB

        # netG: trainable generator, same architecture as netB.
        netG = ResnetConditionHR(
            input_nc=(3, 3, 1, 4),
            output_nc=4,
            n_blocks1=self.opt.n_blocks1,
            n_blocks2=self.opt.n_blocks2,
        )
        netG.apply(conv_init)
        self.netG = netG
        self.netG.to(self.device)

        netD = MultiscaleDiscriminator(
            input_nc=3, num_D=1, norm_layer=nn.InstanceNorm2d, ndf=64
        )
        netD.apply(conv_init)
        self.netD = netD
        self.netD.to(self.device)

        self.l1_loss = alpha_loss()
        self.c_loss = compose_loss()
        self.g_loss = alpha_gradient_loss()
        self.GAN_loss = GANloss()
        self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
        self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)
        # TensorBoard logs and train() checkpoints are written into the data directory.
        self.log_writer = SummaryWriter(datadir)
        self.model_dir = datadir

    @staticmethod
    def _locate_input_data():
        """Return the input-data directory, installing it if absent under ``DATA_PATH``.

        Raises:
            RuntimeError: wrapping any error raised while installing the data.
        """
        input_data_dir_name = "Background_Matting_inputs"
        datadir = os.path.join(DATA_PATH, input_data_dir_name)
        if os.path.exists(datadir):
            return datadir
        try:
            from torchbenchmark.util.framework.fb.installer import install_data

            datadir = install_data(input_data_dir_name)
            # Input data files are decompressed into the folder one level deeper.
            return os.path.join(datadir, input_data_dir_name)
        except Exception as e:
            msg = f"Failed to download data from manifold: {e}"
            raise RuntimeError(msg) from e

    def _load_train_data(self, datadir):
        """Load every training batch from ``datadir`` and move its tensors to ``self.device``.

        Returns:
            A list of collated batch mappings (keys such as ``"bg"``, ``"image"``).
        """
        # The shipped CSV contains ``{scriptdir}`` placeholders; expand them to the
        # real data directory and write the result next to this module.
        csv_file_path = _create_data_dir().joinpath("Video_data_train_processed.csv")
        with open(f"{datadir}/Video_data_train.csv", "r") as src:
            template = src.read()
        with open(csv_file_path, "w") as dst:
            dst.write(template.format(scriptdir=datadir))

        traindata = VideoData(
            csv_file=csv_file_path,
            data_config={"reso": (self.opt.resolution, self.opt.resolution)},
            transform=None,
        )
        train_loader = torch.utils.data.DataLoader(
            traindata,
            batch_size=self.opt.batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=_collate_filter_none,
        )
        train_data = []
        for batch in train_loader:
            # Tensor.to() is not in-place: keep the returned tensor, otherwise the
            # cached batch silently stays where the loader produced it.
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(self.device)
            train_data.append(batch)
        return train_data

    def _batch_to_device(self, data, keys):
        """Return ``data[key].to(self.device)`` for each key in ``keys``, as a tuple."""
        return tuple(data[key].to(self.device) for key in keys)

    def jit_callback(self):
        """Replace ``netB`` and ``netG`` with ``torch.jit.trace`` versions.

        Tracing uses ``(image, bg, seg, multi_fr)`` from the first cached batch;
        nothing happens when there are no cached batches.
        """
        if not self.train_data:
            return
        example_inputs = self._batch_to_device(self.train_data[0], self._NET_INPUT_KEYS)
        self.netB = torch.jit.trace(self.netB, example_inputs)
        self.netG = torch.jit.trace(self.netG, example_inputs)

    def get_module(self):
        """Return ``(netG, (image, bg, seg, multi_fr))`` from the first cached batch.

        Returns ``None`` when there are no cached batches.
        """
        # use netG (generation) for the return module
        if not self.train_data:
            return None
        return self.netG, self._batch_to_device(self.train_data[0], self._NET_INPUT_KEYS)

    def set_module(self, module):
        """Replace the generator ``netG`` with ``module``."""
        self.netG = module

    def train(self):
        """Run generator/discriminator updates on the first ``num_of_batches`` (1) batches.

        For each batch:

        * frozen ``netB`` predicts alpha/foreground used as pseudo-labels;
        * ``netG`` is updated with a GAN loss plus ``wt``-weighted alpha,
          foreground and same-background composition losses;
        * ``netD`` is updated when the batch index ``i`` satisfies ``i % 5 == 0``.

        Per-batch losses go to ``self.log_writer``. Afterwards the ``netG``/``netD``
        weights and optimizer states are saved to ``self.model_dir``.
        """
        self.netG.train()
        self.netD.train()

        # Running sums, printed and reset every ``step`` batches.
        lG = lD = alL = fgL = compL = 0
        elapse_run = elapse = 0
        t0 = time.time()
        KK = len(self.train_data)
        wt = 1  # weight of the pseudo-supervised terms in the generator loss
        epoch = 0
        step = 50
        num_of_batches = 1
        for i, data in zip(range(num_of_batches), self.train_data):
            # Initiating
            bg, image, seg, multi_fr, seg_gt, back_rnd = self._batch_to_device(
                data, ("bg", "image", "seg", "multi_fr", "seg-gt", "back-rnd")
            )
            mask0 = torch.ones(seg.shape, device=self.device)
            tr0 = time.time()

            # pseudo-supervision
            alpha_pred_sup, fg_pred_sup = self.netB(image, bg, seg, multi_fr)
            # .float() keeps each mask on its source tensor's device; the former
            # .type(torch.FloatTensor) branch always produced CPU tensors.
            mask = (alpha_pred_sup > -0.98).float()
            mask1 = (seg_gt > 0.95).float()

            # Train Generator
            alpha_pred, fg_pred = self.netG(image, bg, seg, multi_fr)
            # pseudo-supervised losses
            al_loss = self.l1_loss(
                alpha_pred_sup, alpha_pred, mask0
            ) + 0.5 * self.g_loss(alpha_pred_sup, alpha_pred, mask0)
            fg_loss = self.l1_loss(fg_pred_sup, fg_pred, mask)
            # compose into same background
            comp_loss = self.c_loss(image, alpha_pred, fg_pred, bg, mask1)

            # randomly permute the background
            perm = torch.LongTensor(np.random.permutation(bg.shape[0]))
            bg_sh = bg[perm, :, :, :]
            al_mask = (alpha_pred > 0.95).float()
            # Choose the target background for composition
            # back_rnd: contains separate set of background videos captured
            # bg_sh: contains randomly permuted captured background from the same minibatch
            if np.random.random_sample() > 0.5:
                bg_sh = back_rnd
            image_sh = compose_image_withshift(
                alpha_pred, image * al_mask + fg_pred * (1 - al_mask), bg_sh, seg
            )

            fake_response = self.netD(image_sh)
            loss_ganG = self.GAN_loss(fake_response, label_type=True)
            lossG = loss_ganG + wt * (
                0.05 * comp_loss + 0.05 * al_loss + 0.05 * fg_loss
            )
            self.optimizerG.zero_grad()
            lossG.backward()
            self.optimizerG.step()

            # Train Discriminator
            fake_response = self.netD(image_sh)
            real_response = self.netD(image)
            loss_ganD_fake = self.GAN_loss(fake_response, label_type=False)
            loss_ganD_real = self.GAN_loss(real_response, label_type=True)
            lossD = (loss_ganD_real + loss_ganD_fake) * 0.5
            # Update discriminator for every 5 generator update
            if i % 5 == 0:
                self.optimizerD.zero_grad()
                lossD.backward()
                self.optimizerD.step()

            lG += lossG.data
            lD += lossD.data
            alL += al_loss.data
            fgL += fg_loss.data
            compL += comp_loss.data

            global_step = epoch * KK + i + 1
            for tag, value in (
                ("Generator Loss", lossG),
                ("Discriminator Loss", lossD),
                ("Generator Loss: Fake", loss_ganG),
                ("Discriminator Loss: Real", loss_ganD_real),
                ("Discriminator Loss: Fake", loss_ganD_fake),
                ("Generator Loss: Alpha", al_loss),
                ("Generator Loss: Fg", fg_loss),
                ("Generator Loss: Comp", comp_loss),
            ):
                self.log_writer.add_scalar(tag, value.data, global_step)

            t1 = time.time()
            elapse += t1 - t0
            elapse_run += t1 - tr0
            t0 = t1

            if i % step == (step - 1):
                print(
                    "[%d, %5d] Gen-loss: %.4f Disc-loss: %.4f Alpha-loss: %.4f Fg-loss: %.4f Comp-loss: %.4f Time-all: %.4f Time-fwbw: %.4f"
                    % (
                        epoch + 1,
                        i + 1,
                        lG / step,
                        lD / step,
                        alL / step,
                        fgL / step,
                        compL / step,
                        elapse / step,
                        elapse_run / step,
                    )
                )
                lG = lD = alL = fgL = compL = 0
                elapse_run = elapse = 0
                write_tb_log(image, "image", self.log_writer, i)
                write_tb_log(seg, "seg", self.log_writer, i)
                write_tb_log(alpha_pred_sup, "alpha-sup", self.log_writer, i)
                write_tb_log(alpha_pred, "alpha_pred", self.log_writer, i)
                write_tb_log(fg_pred_sup * mask, "fg-pred-sup", self.log_writer, i)
                write_tb_log(fg_pred * mask, "fg_pred", self.log_writer, i)
                # composition
                alpha_pred = (alpha_pred + 1) / 2
                comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
                write_tb_log(comp, "composite-same", self.log_writer, i)
                write_tb_log(image_sh, "composite-diff", self.log_writer, i)
                del comp

            # Drop per-batch references so the tensors can be freed before the next batch.
            del (
                mask,
                back_rnd,
                mask0,
                seg_gt,
                mask1,
                bg,
                alpha_pred,
                alpha_pred_sup,
                image,
                fg_pred_sup,
                fg_pred,
                seg,
                multi_fr,
                image_sh,
                bg_sh,
                fake_response,
                real_response,
                al_loss,
                fg_loss,
                comp_loss,
                lossG,
                lossD,
                loss_ganD_real,
                loss_ganD_fake,
                loss_ganG,
                value,
            )

        # ``epoch`` is fixed at 0 above, so checkpoints are written on every call.
        # (A multi-epoch schedule would also halve ``wt`` every 2 epochs; with a
        # single epoch that update had no effect and is omitted.)
        if epoch % 2 == 0:
            torch.save(
                self.netG.state_dict(),
                os.path.join(self.model_dir, "netG_epoch_%d.pth" % (epoch)),
            )
            torch.save(
                self.optimizerG.state_dict(),
                os.path.join(self.model_dir, "optimG_epoch_%d.pth" % (epoch)),
            )
            torch.save(
                self.netD.state_dict(),
                os.path.join(self.model_dir, "netD_epoch_%d.pth" % (epoch)),
            )
            torch.save(
                self.optimizerD.state_dict(),
                os.path.join(self.model_dir, "optimD_epoch_%d.pth" % (epoch)),
            )

    def eval(self):
        """Evaluation is not supported for this model."""
        raise NotImplementedError()