-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgetsmaps.py
90 lines (71 loc) · 2.76 KB
/
getsmaps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import torch
import torch.utils.data as Data
import os
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from ORSI_SOD_dataset import ORSI_SOD_Dataset
from tqdm import tqdm
from src.UG2L import net as Net
from evaluator import Eval_thread
from PIL import Image
import time
# Expose only the GPU with index 2 to this process (machine-specific setting).
# NOTE(review): this runs at import time, ahead of every ``.cuda()`` call in
# this file; adjust the index to match the local machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
def unload(x):
    """Convert a tensor into a NumPy array living in host memory.

    Size-1 dimensions are squeezed away, the tensor is detached from the
    autograd graph, copied to the CPU and returned as a ``numpy.ndarray``.
    """
    # ``detach()`` is the supported replacement for the legacy ``.data``
    # attribute; detaching before the copy also keeps the device transfer
    # out of the autograd graph.
    return x.squeeze().detach().cpu().numpy()
def convert2img(x):
    """Build a PIL image in mode ``'L'`` from an array.

    The array is multiplied by 255 before being handed to
    ``Image.fromarray``.
    NOTE(review): presumably ``x`` holds values in [0, 1] -- confirm
    against the callers.
    """
    scaled = x * 255
    picture = Image.fromarray(scaled)
    return picture.convert('L')
def min_max_normalization(x):
    """Linearly rescale ``x`` so that its minimum maps to 0 and its maximum to 1.

    Args:
        x: array-like of numbers.

    Returns:
        A ``numpy.ndarray`` with the same shape as ``x``.  When every element
        of ``x`` is equal (zero value range) an all-zero array is returned
        instead of the NaNs a division by zero would produce.
    """
    arr = np.asarray(x)
    lowest = arr.min()
    span = arr.max() - lowest
    if span == 0:
        # Constant input: there is no contrast to stretch, so avoid 0/0.
        is_float = np.issubdtype(arr.dtype, np.floating)
        return np.zeros(arr.shape, dtype=arr.dtype if is_float else np.float64)
    return (arr - lowest) / span
def save_smap(smap, path, negative_threshold=0.25):
    """Write a predicted saliency map to ``path`` as a grayscale image.

    Args:
        smap: saliency tensor, documented by the original author as
            ``[1, H, W]``.  It is not modified by this function.
        path: destination passed to the image's ``save`` method.
        negative_threshold: when the maximum of ``smap`` does not exceed
            this value, entries below the threshold are set to 0 and the
            map is saved without rescaling; otherwise the map is min-max
            normalised before saving.
    """
    # smap: [1, H, W]
    if torch.max(smap) <= negative_threshold:
        # ``masked_fill`` returns a new tensor, so the caller's map is left
        # untouched (the in-place indexed assignment used to overwrite it).
        suppressed = smap.masked_fill(smap < negative_threshold, 0)
        image = convert2img(unload(suppressed))
    else:
        image = convert2img(min_max_normalization(unload(smap)))
    image.save(path)
def getsmaps(dataset_name):
    """Predict and save UG2L saliency maps for one dataset's test split.

    The function builds the test loader, loads the checkpoint matching
    ``dataset_name``, runs the network image by image on the GPU and writes
    each map to ``./data/output/predict_smaps_UG2L_<dataset_name>/``.  It
    finally prints the number of test batches, the accumulated forward-pass
    time in seconds and, when that time is non-zero, the batches per second.

    Args:
        dataset_name: ``"ORSSD"``, ``"EORSSD"`` or ``"ORS_4199"``.

    Raises:
        ValueError: if ``dataset_name`` has no associated checkpoint.
    """
    # Checked first so that an unknown name fails fast instead of silently
    # running the network with randomly initialised weights.
    weight_paths = {
        "ORSSD": "./data/weights/ORSSD_weights.pth",
        "EORSSD": "./data/weights/EORSSD_weights.pth",
        "ORS_4199": "./data/weights/ORS_4199_weights.pth",
    }
    if dataset_name not in weight_paths:
        raise ValueError(
            f"unknown dataset {dataset_name!r}; expected one of {sorted(weight_paths)}"
        )

    ##define dataset
    dataset_root = "/data/iopen/lyf/SaliencyOD_in_RSIs/" + dataset_name + " dataset/"
    test_set = ORSI_SOD_Dataset(root=dataset_root, mode="test", aug=False)
    # Shuffling is left off: the loop below treats each batch independently
    # and names every output file after its own image.
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1)

    ##define network and load weight
    net = Net().cuda().eval()
    # Load onto the CPU first so the checkpoint does not depend on the GPU
    # index it was saved from; ``load_state_dict`` copies into the GPU model.
    state = torch.load(weight_paths[dataset_name], map_location="cpu")  ##UG2L
    net.load_state_dict(state)

    # Create the output directory once instead of checking on every image.
    dirs = "./data/output/predict_smaps" + "_UG2L_" + dataset_name
    os.makedirs(dirs, exist_ok=True)

    ##save saliency map
    infer_time = 0.0
    with torch.no_grad():
        for image, _, _, name in tqdm(test_loader):
            image = image.cuda()

            # CUDA kernels are launched asynchronously; synchronise on both
            # sides so the interval covers the whole forward pass.
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            smap = net(image)
            torch.cuda.synchronize()
            t2 = time.perf_counter()
            infer_time += (t2 - t1)

            path = os.path.join(dirs, name[0] + "_UG2L" + '.png')
            save_smap(smap, path)

    print(len(test_loader))
    print(infer_time)
    if infer_time > 0:
        # inference speed (without I/O time); skipped for an empty loader
        # to avoid dividing by zero.
        print(len(test_loader) / infer_time)
if __name__ == "__main__":
    # Report the model's complexity once, on a network instance used only
    # for profiling (getsmaps builds its own network for inference).
    from thop import profile

    net = Net().cuda().eval()
    # Use an initialised dummy batch: ``torch.Tensor(1, 3, 448, 448)`` returns
    # uninitialised memory that may contain NaN/Inf values.
    x = torch.randn(1, 3, 448, 448, device="cuda")
    macs, params = profile(net, inputs=(x, ), verbose=False)
    print('flops: ', f'{macs/1e9}GMac', 'params: ', f'{params/1e6}M')

    # Generate the saliency maps for each dataset in turn.
    dataset = ["ORSSD", "EORSSD", "ORS_4199"]
    for dataset_name in dataset:
        getsmaps(dataset_name)