-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_sequence.py
157 lines (123 loc) · 4.32 KB
/
test_sequence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Package Includes
from __future__ import division
import os
import socket
import timeit
from datetime import datetime
# PyTorch includes
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
# Custom includes
from dataloaders import sequence_loader as db
from dataloaders import custom_transforms as tr
import networks.salar as salar
from dataloaders.helpers import *
from mypath import Path
from saliency.saliency_metrics import AUC_Judd, CC, NSS, SIM
import cv2
from saliency.postprocess_util import postprocess_prediction
from saliency.postprocess_util import normalize_map
import numpy as np
import imageio
def superimpose(image, heatmap):
    """Blend a color-mapped heatmap over an image.

    The heatmap is divided by its maximum, scaled to 0-255, colorized with
    OpenCV colormap id 4 and mixed with the channel-swapped image using the
    weights 0.7 (heatmap) and 0.3 (image).

    Args:
        image: Image array handed to ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2RGB``.
        heatmap: Numeric array to overlay. Presumably single-channel with
            the same height and width as ``image`` -- TODO confirm against
            the callers.

    Returns:
        The blended image as an ``np.uint8`` array.
    """
    peak = heatmap.max()
    if peak > 0:
        hmap = heatmap / peak
    else:
        # An all-zero (or NaN-peaked) map cannot be normalized; dividing
        # would feed NaNs into the uint8 cast below. Draw it as an empty map.
        hmap = np.zeros(heatmap.shape, dtype=np.float64)
    hmap = (hmap * 255).astype(np.uint8)
    # 4 is cv2.COLORMAP_RAINBOW in OpenCV's colormap enumeration.
    # NOTE(review): applyColorMap is documented to return BGR while the image
    # below is converted to RGB -- confirm the mixed channel order is intended.
    hmap = cv2.applyColorMap(hmap, 4)
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    blended = 0.3 * img + 0.7 * hmap
    return blended.astype(np.uint8)
# Name of the sequence to visualize; it is passed to the sequence loader and
# also used as the base name of the saved .mp4 file.
seq_name = 'Diving-Side-001'
# Directory the model checkpoint is loaded from (provided by the project's
# Path helper).
save_dir = Path.save_root_dir()
# Select which GPU, -1 if CPU
gpu_id = 1
# A negative id means "run on CPU" and must never be turned into the invalid
# device string "cuda:-1", even when CUDA is available on the machine.
use_gpu = gpu_id >= 0 and torch.cuda.is_available()
device = torch.device("cuda:" + str(gpu_id) if use_gpu else "cpu")
if use_gpu:
    print('Using GPU: {} '.format(gpu_id))
# Result-handling switches (not referenced by the rest of this script).
VIS_RES = 1  # Visualize the results?
SAVE_RES = 1  # Save the results?
WRITE_RES = 1  # Write the results?

# Base name of the checkpoint file (without the '.pth' extension).
modelName = 'ucf_salar'

# Dataset switches: enable the dataset the sequence belongs to.
ucf = True
hollywood = False
dhf1k = False

# (enabled?, run sub-directory, dataset root) for every supported dataset.
_DATASETS = (
    (ucf, 'ucf', './dataloaders/ucf'),
    (hollywood, 'hollywood', './dataloaders/hollywood'),
    (dhf1k, 'dhf1k', './dataloaders/dhf1k'),
)
_enabled_datasets = []
for _flag, _dataset_name, _dataset_root in _DATASETS:
    if _flag:
        _enabled_datasets.append((_dataset_name, _dataset_root))

if not _enabled_datasets:
    # Without this check the script only fails later with a NameError on
    # db_root_dir / inputRes, which hides the real cause.
    raise ValueError(
        'No dataset selected: set one of ucf, hollywood or dhf1k to True')

# The first enabled dataset decides where the data is read from.
db_root_dir = _enabled_datasets[0][1]
inputRes = (180, 320, 3)

print('Using model: ' + str(modelName))

# Every enabled dataset contributes one path component to the run directory.
HOME_PATH = './runs'
for _dataset_name, _dataset_root in _enabled_datasets:
    print(_dataset_name)
    HOME_PATH = os.path.join(HOME_PATH, _dataset_name)

OUTPUT_PATH = os.path.join(HOME_PATH, 'visualize')

# create output directory
if not os.path.isdir(OUTPUT_PATH):
    os.makedirs(OUTPUT_PATH)
# Build the network and restore its weights from '<save_dir>/<modelName>.pth'.
# The map_location callback returns the storage unchanged, so tensors are
# deserialized without being moved to the device they were saved from.
net = salar.SalAR(pretrained=0)
checkpoint = torch.load(
    os.path.join(save_dir, modelName + '.pth'),
    map_location=lambda storage, loc: storage,
)
net.load_state_dict(checkpoint)
net.to(device)

# Dataset for the selected sequence, iterated in order one sample at a time.
db_test = db.SequenceLoader(
    inputRes=inputRes,
    originalRes=None,
    db_root_dir=db_root_dir,
    transform=tr.ToTensor(),
    seq_name=seq_name,
)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=0)
num_img_ts = len(testloader)

# Bookkeeping values that the rest of this script does not read.
loss_tr = []
aveGrad = 0

# One figure with three side-by-side panels for the animation.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 3)

# Per-frame buffers filled during inference and read when rendering.
anim_img, anim_gt, anim_pred = [], [], []
anim_len = len(testloader)

# Put the network into evaluation mode before inference.
net = net.eval()
# Run the network over the sequence and store, for every sample, the original
# frame, its ground-truth map and the post-processed prediction.
with torch.no_grad():
    for sample_batched in testloader:
        img = sample_batched['image']
        gt = sample_batched['gt']
        img_orig = sample_batched['img_orig']

        # Forward of the mini-batch
        inputs = img.to(device)
        outputs = net(inputs)

        # Only the last network output is visualized. Convert each tensor to
        # numpy once per batch instead of once per sample.
        batch_pred = torch.relu(outputs[-1]).detach().cpu().numpy()
        batch_orig = img_orig.numpy()
        batch_gt = gt.detach().cpu().numpy()
        # Predictions are resized to the spatial size of the ground truth
        # (gt is indexed as a 4-D batch here).
        target_size = (gt.shape[2], gt.shape[3])

        for jj in range(int(inputs.size()[0])):
            # Channel-first -> channel-last, then drop singleton axes.
            pred = np.squeeze(np.transpose(batch_pred[jj, :, :, :], (1, 2, 0)))
            # Keep the batch tensor intact: every sample is sliced from it,
            # so the per-sample frame needs a name of its own.
            frame = np.transpose(batch_orig[jj, :, :, :], (1, 2, 0))

            prediction = normalize_map(pred)
            prediction = postprocess_prediction(prediction, target_size)
            prediction = normalize_map(prediction)
            prediction *= 255

            anim_img.append(frame)
            # Index the ground truth per sample so that each stored map
            # belongs to the frame stored alongside it.
            anim_gt.append(np.squeeze(batch_gt[jj]))
            anim_pred.append(im_normalize(prediction))
def update_animation(i):
    """Redraw the three figure panels for animation frame ``i``.

    Clears every axis, then shows the channel-swapped image, the image with
    the ground-truth map overlaid and the image with the prediction overlaid,
    and finally hides the axis decorations.

    Args:
        i: Index into the module-level ``anim_img``, ``anim_gt`` and
            ``anim_pred`` lists.
    """
    panel_titles = ('Image', 'Ground Truth', 'Prediction')

    for axis in ax:
        axis.cla()
    for axis, title in zip(ax, panel_titles):
        axis.set_title(title)

    frame = anim_img[i]
    ax[0].imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    ax[1].imshow(superimpose(frame, anim_gt[i]))
    ax[2].imshow(superimpose(frame, anim_pred[i]))

    for axis in ax:
        axis.axis('off')
from matplotlib import animation

# Render the collected frames and encode them as '<seq_name>.mp4' (H.264 via
# ffmpeg, 30 fps) inside the output directory.
save_path = os.path.join(OUTPUT_PATH, seq_name + '.mp4')
FFwriter = animation.FFMpegWriter(fps=30, codec="libx264")
# update_animation indexes anim_img / anim_gt / anim_pred by frame number, so
# the frame count is taken from the collected frames themselves rather than
# from the number of loader batches.
num_frames = len(anim_img)
anim = animation.FuncAnimation(
    fig,
    update_animation,
    frames=num_frames,
    interval=100,
    save_count=num_frames,
)
anim.save(save_path, writer=FFwriter)