-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDEMO
320 lines (292 loc) · 12.1 KB
/
DEMO
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# import the necessary packages
import os
import random
import time
import cv2
import numpy as np
import PIL
import torch
import torch.nn as nn
from torchvision import transforms
from unisal import model_v3 as model
from unisal import utils
#from . import data
cv2.setNumThreads(0)
torch.set_num_threads(6)
class Trainer(utils.KwConfigClass):
    """
    Trainer class that handles training, evaluation and inference.

    The methods defined in this file only cover single-image inference
    (``preprocess``, ``run_inference``, ``generate_predictions``). The
    training-related arguments below are stored on the instance but are
    not read by any method in this file.

    Arguments:
        num_epochs: Number of training epochs
        optim_algo: Optimization algorithm (e.g. 'SGD')
        momentum: Optimizer momentum if applicable
        lr: Learning rate
        lr_scheduler: Learning rate scheduler (e.g. 'ExponentialLR')
        lr_gamma: Learning rate decay for 'ExponentialLR' scheduler
        weight_decay: Weight decay (except for CNN)
        cnn_weight_decay: Backbone CNN weight decay
        grad_clip: Gradient clipping magnitude
        loss_metrics: Loss metrics. Defaults equivalent to [1].
        loss_weights: Weights of the individual loss metrics. Defaults
            equivalent to [1].
        data_sources: Data sources. Default equivalent to [1].
        batch_size: DHF1K batch size
        salicon_batch_size: SALICON batch size
        hollywood_batch_size: Hollywood-2 batch size
        ucfsports_batch_size: UCFSports batch size
        salicon_weight: Weight of the SALICON loss. Default is 0.5 to
            account for the larger number of batches.
        hollywood_weight: Weight of the Hollywood-2 loss.
        ucfsports_weight: Weight of the UCF Sports loss.
        data_cfg: Dictionary with kwargs for the DHF1KDataset class.
        salicon_cfg: Dictionary with kwargs for the SALICONDataset class.
        hollywood_cfg: Dictionary with kwargs for the HollywoodDataset
            class.
        ucfsports_cfg: Dictionary with kwargs for the UCFSportsDataset
            class.
        shuffle_datasets: Whether to train on batches of the individual
            datasets in random order. If False, batches are drawn
            in alternating order.
        cnn_lr_factor: Factor of the backbone CNN learning rate compared
            to the overall learning rate.
        train_cnn_after: Freeze the backbone CNN for N epochs.
        cnn_eval: If True, keep the backbone CNN in evaluation mode (use
            pretrained BatchNorm running estimates for mean and variance).
        model_cfg: Dictionary with kwargs for the model class. If it has
            no 'sources' entry, ``data_sources`` is used for it.
        prefix: Prefix for the training folder name. Defaults to
            timestamp.
        suffix: Suffix for the training folder name.
        num_workers: Number of parallel workers for data loading.
        chkpnt_warmup: Number of epochs before saving the first
            checkpoint.
        chkpnt_epochs: Save a checkpoint every N epochs.
        tboard: Use TensorboardX to visualize the training.
        debug: Debug mode.

    [1] https://arxiv.org/abs/1801.07424
    """

    phases = ('train', 'valid')
    all_data_sources = ('SALICON',)

    def __init__(self,
                 num_epochs=32,
                 optim_algo='SGD',
                 momentum=0.9,
                 lr=0.04,
                 lr_scheduler='ExponentialLR',
                 lr_gamma=0.99,
                 weight_decay=1e-4,
                 cnn_weight_decay=1e-5,
                 grad_clip=2.,
                 loss_metrics=('kld', 'nss', 'cc'),
                 loss_weights=(1.5, -0.1, -0.1),
                 data_sources=('SALICON',),
                 batch_size=4,
                 salicon_batch_size=32,
                 hollywood_batch_size=4,
                 ucfsports_batch_size=4,
                 salicon_weight=.5,
                 hollywood_weight=1.,
                 ucfsports_weight=1.,
                 data_cfg=None,
                 salicon_cfg=None,
                 hollywood_cfg=None,
                 ucfsports_cfg=None,
                 shuffle_datasets=True,
                 cnn_lr_factor=0.1,
                 train_cnn_after=2,
                 cnn_eval=True,
                 model_cfg=None,
                 prefix=None,
                 suffix='unisal',
                 num_workers=6,
                 chkpnt_warmup=3,
                 chkpnt_epochs=2,
                 tboard=True,
                 debug=False,
                 ):
        """Store the configuration and reset all lazily built state."""
        # Save training parameters
        self.num_epochs = num_epochs
        self.optim_algo = optim_algo
        self.momentum = momentum
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay
        self.cnn_weight_decay = cnn_weight_decay
        self.grad_clip = grad_clip
        self.loss_metrics = loss_metrics
        self.loss_weights = loss_weights
        self.data_sources = data_sources
        self.batch_size = batch_size
        self.salicon_batch_size = salicon_batch_size
        self.hollywood_batch_size = hollywood_batch_size
        self.ucfsports_batch_size = ucfsports_batch_size
        self.salicon_weight = salicon_weight
        self.hollywood_weight = hollywood_weight
        self.ucfsports_weight = ucfsports_weight
        self.data_cfg = data_cfg or {}
        self.salicon_cfg = salicon_cfg or {}
        self.hollywood_cfg = hollywood_cfg or {}
        self.ucfsports_cfg = ucfsports_cfg or {}
        self.shuffle_datasets = shuffle_datasets
        self.cnn_lr_factor = cnn_lr_factor
        self.train_cnn_after = train_cnn_after
        self.cnn_eval = cnn_eval
        # Work on a copy so that injecting 'sources' below does not
        # modify the dictionary object owned by the caller.
        self.model_cfg = dict(model_cfg) if model_cfg else {}
        if 'sources' not in self.model_cfg:
            self.model_cfg['sources'] = data_sources

        # Naming of the training folder
        self.suffix = suffix
        if prefix is None:
            prefix = utils.get_timestamp()
        self.prefix = prefix

        # Other operational parameters
        self.num_workers = num_workers
        self.chkpnt_warmup = chkpnt_warmup
        self.chkpnt_epochs = chkpnt_epochs
        # The device is pinned to the CPU: generate_predictions reports
        # single-frame CPU timings.
        self.device = torch.device('cpu')
        self.tboard = tboard
        self.debug = debug

        # Initialize properties etc.
        self.epoch = 0
        self.phase = None
        self._datasets = {}
        self._dataloaders = {}
        self._scheduler = None
        self._optimizer = None
        self._model = None
        self.best_epoch = 0
        self.best_val_score = None
        self.is_best = False
        self.all_scalars = {}
        self._writer = None
        self._salicon_datasets = {}
        self._salicon_dataloaders = {}
        self._hollywood_datasets = {}
        self._hollywood_dataloaders = {}
        self._ucfsports_datasets = {}
        self._ucfsports_dataloaders = {}
        self.mit1003_finetuned = False
        # Per-channel statistics passed to transforms.Normalize in
        # preprocess().
        self.preproc_cfg = {
            'rgb_mean': (0.485, 0.456, 0.406),
            'rgb_std': (0.229, 0.224, 0.225),
        }

    @property
    def model(self):
        """Return the model, building it and moving it to self.device on
        first access."""
        if self._model is None:
            model_cls = model.get_model()
            self._model = model_cls(**self.model_cfg)
            self._model.to(self.device)
        return self._model

    def preprocess(self, img, out_size):
        """Convert an image array into a resized (and normalized) tensor.

        Arguments:
            img: Image accepted by ``transforms.ToPILImage``.
            out_size: Target size handed to ``transforms.Resize``.

        Returns:
            The tensor produced by ``transforms.ToTensor``, normalized
            with ``preproc_cfg['rgb_mean']`` / ``preproc_cfg['rgb_std']``
            when 'rgb_mean' is configured.
        """
        transformations = [
            transforms.ToPILImage(),
            transforms.Resize(
                out_size, interpolation=PIL.Image.LANCZOS),
            transforms.ToTensor(),
        ]
        if 'rgb_mean' in self.preproc_cfg:
            transformations.append(
                transforms.Normalize(
                    self.preproc_cfg['rgb_mean'],
                    self.preproc_cfg['rgb_std']))
        processing = transforms.Compose(transformations)
        return processing(img)

    def run_inference(self, sample):
        """Run a forward pass of the model on a single frame.

        Arguments:
            sample: Preprocessed frame tensor. A 3-dimensional tensor
                gets two leading singleton dimensions, any other tensor
                gets one, before it is passed to the model.

        Returns:
            Float tensor of shape (1, 1, 1, 96, 128) holding the model
            output for frame 0.
        """
        random.seed(27)

        # Fixed output resolution of the prediction
        target_size = (96, 128)

        # Define input sequence length
        seq_len = 6

        # Keyword arguments for the forward pass.
        # NOTE(review): 'static' receives the (truthy) string 'SALICON';
        # confirm whether the model expects a boolean here.
        model_kwargs = {
            'source': 'SALICON',
            'target_size': target_size,
            'static': 'SALICON',
        }

        # Prepare the prediction tensor
        results_size = (1, 1, 1, target_size[0], target_size[1])
        pred_seq = torch.full(results_size, 0, dtype=torch.float)

        # Add the missing leading dimensions and move to the device
        frame_seq = sample
        if frame_seq.dim() == 3:
            frame_seq = frame_seq.unsqueeze(0)
        frame_seq = frame_seq.unsqueeze(0).float()
        frame_idx_array = [0]
        frame_seq = frame_seq.to(self.device)

        # range(0, 1, seq_len) yields only start == 0, so exactly one
        # chunk (containing frame 0) is processed.
        h0 = [None]
        for start in range(0, 1, seq_len):
            # Select the frames
            end = min(len(frame_idx_array), start + seq_len)
            this_frame_seq = frame_seq[:, start:end, :, :, :]
            this_frame_idx_array = frame_idx_array[start:end]

            # Forward pass
            this_pred_seq, h0 = self.model(
                this_frame_seq, h0=h0, return_hidden=True,
                **model_kwargs)

            # Insert the predictions into the prediction array
            this_pred_seq = this_pred_seq.cpu()
            pred_seq[:, this_frame_idx_array, :, :, :] = this_pred_seq

        return pred_seq

    def generate_predictions(self, img, train_id=None,
                             save_predictions=True, load_weights=True,
                             num_timing_runs=1000):
        """Predict a saliency map for one image and report CPU timings.

        Arguments:
            img: Image array with the color channels on the last axis;
                the channel order is reversed before preprocessing
                (presumably BGR from ``cv2.imread`` -> RGB).
            train_id: Name of the training run folder that contains the
                weights and receives the saved prediction. Required.
            save_predictions: If True, write the prediction to
                ``<train_dir>/boss1_000001.jpg``.
            load_weights: If True, load the best weights from the run
                folder, falling back to the last checkpoint.
            num_timing_runs: Number of repeated inference runs used for
                the timing statistics. With 0 the statistics are skipped.

        Returns:
            The prediction tensor returned by ``run_inference``.

        Raises:
            TypeError: If ``train_id`` is None.
        """
        if train_id is None:
            # Fail with a clear message instead of the opaque
            # "can only concatenate str" error of the concatenation below.
            raise TypeError(
                'train_id must be the name of a training run folder, '
                'not None')
        self.train_dir = (
            '/home/liucong/unisal-master/training_runs/' + train_id)

        # Reverse the channel order and make the result contiguous
        img = np.ascontiguousarray(img[:, :, ::-1])
        img = self.preprocess(img, (96, 128))

        if load_weights:
            # Load the best weights, if available, otherwise the weights
            # of the last epoch
            try:
                self.model.load_best_weights1(self.train_dir)
                print('Best weights loaded')
            except FileNotFoundError:
                print('No best weights found')
                self.model.load_last_chkpnt(self.train_dir)
                print('Last checkpoint loaded')

        with torch.no_grad():
            # Prepare the model
            self.model.to(self.device)
            self.model.eval()
            torch.cuda.empty_cache()

            # First run, timed separately from the repeated runs below
            t0 = time.perf_counter()
            pred_seq = self.run_inference(img)
            print(time.perf_counter() - t0)

            # Repeated runs for the timing statistics
            times = []
            for _ in range(num_timing_runs):
                t0 = time.perf_counter()
                pred_seq = self.run_inference(img)
                times.append(time.perf_counter() - t0)
            print()

            if times:
                stats = (
                    ('Avg', sum(times) / len(times)),
                    ('Min', min(times)),
                    ('Max', max(times)),
                )
                for label, dt in stats:
                    # Guard against a zero reading of the timer
                    fps = 1 / dt if dt > 0 else float('inf')
                    print(label + " single-frame CPU time:" + str(dt)
                          + "s(" + str(fps) + "fps)")

        if save_predictions:
            smap = pred_seq[:, 0, ...]

            # Postprocess prediction.
            # NOTE(review): exp() suggests the model output is in the
            # log domain - confirm against the model definition.
            smap = smap.exp()
            smap = torch.squeeze(smap)
            smap = utils.to_numpy(smap)

            # Save prediction as image, scaled so the maximum is 255
            filename = 'boss1_000001.jpg'
            smap = (smap / np.amax(smap) * 255).astype(np.uint8)
            pred_file = os.path.join(self.train_dir, filename)
            cv2.imwrite(
                str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100])

        return pred_seq
if __name__ == '__main__':
    # Demo entry point: predict a saliency map for one image on disk.
    unisal = Trainer()
    image_path = './boss1_000001.jpg'
    input_img = cv2.imread(image_path)
    # cv2.imread does not raise on a missing or unreadable file; it
    # returns None, which would only fail later with an unrelated error.
    if input_img is None:
        raise FileNotFoundError('Could not read image: ' + image_path)
    result = unisal.generate_predictions(
        input_img, train_id='2020-10-25_10:42:52_unisal')
    print(result)