-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
119 lines (101 loc) · 3.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import numpy as np
import cv2
import os
def get_clothes_mask(old_label):
    """Return a float mask that is 1.0 where ``old_label`` equals 3, else 0.0.

    (Label 3 is presumably the clothes class of the parsing map -- confirm
    against the dataset's label definition.)

    Args:
        old_label: label tensor of any shape and device.

    Returns:
        float32 tensor on the CPU with the same shape as ``old_label``.
    """
    # The previous implementation round-tripped through NumPy using the
    # ``np.int`` alias, which was removed in NumPy 1.24 (AttributeError), and
    # ``.numpy()`` also failed for tensors requiring grad. A pure-torch
    # comparison gives the same float32 CPU mask without either problem.
    return (old_label.detach().cpu() == 3).float()
def changearm(old_label):
    """Relabel every pixel with label 5 or 6 to label 3.

    Args:
        old_label: label tensor of any shape; it is not modified.

    Returns:
        A new tensor on ``old_label``'s device. It uses the promoted
        floating dtype of ``old_label`` and a float32 mask, so integer input
        yields float32.
    """
    # Build the masks with torch on old_label's own device. The previous
    # NumPy round-trip used the removed ``np.int`` alias and produced CPU
    # masks that could not be combined with a CUDA ``old_label``.
    arm1 = (old_label == 5).float()
    arm2 = (old_label == 6).float()
    # Arithmetic blend: keep the original value where the mask is 0 and
    # write 3 where it is 1.
    label = old_label * (1 - arm1) + arm1 * 3
    label = label * (1 - arm2) + arm2 * 3
    return label
def gen_noise(shape):
    """Generate a random binary noise tensor of the given shape.

    ``cv2.randn`` fills a uint8 buffer with normal samples (mean 0, stddev
    255), saturated to the uint8 range. Dividing by 255 and casting back to
    uint8 truncates, so the result contains only 0 and 1. A pixel is 1 only
    where the sample saturated at 255.

    Returns:
        float32 tensor of ``shape`` with values in {0.0, 1.0}.
    """
    samples = cv2.randn(np.zeros(shape, dtype=np.uint8), 0, 255)
    binary = (samples / 255).astype(np.uint8)
    return torch.tensor(binary, dtype=torch.float32)
def cross_entropy2d(input, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy between class scores and an integer label map.

    Args:
        input: class scores, unpacked as (n, c, h, w). If (h, w) differs from
            the target's spatial size, the scores are bilinearly resized
            (``align_corners=True``) to match.
        target: integer labels, unpacked as (n, h, w). Pixels equal to 250
            are ignored.
        weight: optional per-class weight forwarded to ``F.cross_entropy``.
        size_average: legacy flag. True or None averages over non-ignored
            pixels; a falsy value sums instead.

    Returns:
        Scalar loss tensor.
    """
    _, c, h, w = input.size()
    _, ht, wt = target.size()
    # Handle inconsistent size between input and target.
    if h != ht or w != wt:
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
    # (n, c, h, w) -> (n*h*w, c): one row of class scores per pixel.
    logits = input.permute(0, 2, 3, 1).reshape(-1, c)
    # reshape (not view) so a non-contiguous target does not raise.
    labels = target.reshape(-1)
    # `size_average` is deprecated in F.cross_entropy. Translate it to the
    # equivalent `reduction`, mirroring the legacy rule where None means True.
    reduction = "mean" if size_average is None or size_average else "sum"
    return F.cross_entropy(
        logits, labels, weight=weight, reduction=reduction, ignore_index=250
    )
def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0):
    """Collapse one sample's channel scores into a per-pixel index map.

    Takes ``image_tensor[batch]``, converts it to a float32 NumPy array on the
    CPU and picks, for every position, the index of the largest value along
    axis 0 (ties go to the first maximum).

    Returns:
        NumPy array of dtype ``imtype`` with the leading (channel) axis removed.
    """
    scores = image_tensor[batch].cpu().float().numpy()
    return scores.argmax(axis=0).astype(imtype)
def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0):
    """Render one sample of a segmentation map as a colour image.

    Args:
        input: segmentation tensor; it is detached before use.
            - If ``multi_channel`` is True, it is decoded with
              ``ndim_tensor2im`` (argmax over the channels of
              ``input[batch]``).
            - Otherwise channel 0 of ``input[batch]`` is used directly as the
              label map.
        multi_channel: whether ``input`` holds per-class scores (argmax
            decoded) or an already-decoded label channel.
        tensor_out: if True, return the RGB rendering as a tensor produced by
            ``transforms.ToTensor``; otherwise return the palettised PIL image.
        batch: index of the sample to render.

    Labels are cast to uint8 and coloured with the fixed 20-entry palette
    below. NOTE(review): labels >= 20 have no palette entry here; their colour
    is whatever PIL uses for unset palette slots -- confirm if that matters.
    """
    palette = [
        0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51,
        254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85,
        85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220,
        0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0,
    ]
    seg = input.detach()
    if multi_channel:
        labels = ndim_tensor2im(seg, batch=batch)
    else:
        labels = np.asarray(seg[batch][0].cpu())
    labels = labels.astype(np.uint8)
    # A 2-D uint8 array becomes an 'L' image and putpalette() switches it to
    # 'P' over the same index data. This avoids Image.fromarray's `mode`
    # argument, which Pillow 11.3 deprecated.
    image = Image.fromarray(labels)
    image.putpalette(palette)
    if tensor_out:
        return transforms.ToTensor()(image.convert('RGB'))
    return image
def pred_to_onehot(prediction, num_classes=13):
    """Convert per-class scores into a one-hot map of the arg-max class.

    Args:
        prediction: tensor indexed as (batch, channel, height, width); the
            predicted class of each pixel is its argmax over dim 1.
        num_classes: number of one-hot output channels. The default of 13 is
            the previously hard-coded value. Every argmax index must be below
            it, otherwise ``scatter_`` raises.

    Returns:
        float32 tensor of shape (batch, num_classes, height, width) on the
        same device as ``prediction``. It holds 1.0 at the arg-max channel and
        0.0 elsewhere.
    """
    size = prediction.shape
    labels = torch.argmax(prediction, dim=1, keepdim=True)  # int64, (B, 1, H, W)
    # Allocate on prediction's device; the legacy torch.FloatTensor(size)
    # constructor always allocated on the CPU, so scatter_ failed for CUDA input.
    onehot = torch.zeros(
        (size[0], num_classes, size[2], size[3]),
        dtype=torch.float32,
        device=prediction.device,
    )
    return onehot.scatter_(1, labels, 1.0)
def cal_miou(prediction, target, classes=(1, 2, 3, 4, 5, 6, 7, 8)):
    """Pooled intersection-over-union of the arg-max prediction vs. a target.

    Despite the name, this is not a mean of per-class IoUs. Intersection and
    union pixel counts are summed over every sample and every class in
    ``classes``, then divided once.

    Args:
        prediction: class scores indexed (batch, channel, height, width). The
            predicted class of each pixel is its argmax over dim 1.
        target: tensor indexed (batch, channel, ...); a nonzero
            ``target[b, c]`` marks class ``c`` as present at that pixel.
        classes: channel indices to include. Defaults to 1-8, the previously
            hard-coded list.

    Returns:
        Python float in [0, 1], or NaN when the pooled union is empty (e.g. an
        empty batch, or no selected class present in either map). The
        previous implementation raised in that case.
    """
    # Comparing the argmax against each class is equivalent to reading the
    # matching channel of a one-hot encoding, without a fixed channel count.
    pred_labels = torch.argmax(prediction.detach().cpu(), dim=1)
    target = target.cpu()
    intersection = 0
    union = 0
    for cls in classes:
        pred_mask = pred_labels == cls
        target_mask = target[:, cls] != 0
        intersection += torch.logical_and(target_mask, pred_mask).sum().item()
        union += torch.logical_or(target_mask, pred_mask).sum().item()
    if union == 0:
        return float('nan')
    return intersection / union
def save_images(img_tensors, img_names, save_dir, image_format='JPEG'):
    """Write each image tensor to ``save_dir`` under its matching name.

    Each tensor is mapped from [-1, 1] to [0, 255] via ``(x + 1) * 0.5 * 255``,
    clamped, and truncated to uint8. Tensors are assumed to be channel-first.
    - A single-channel tensor is squeezed to a 2-D greyscale image.
    - A 3-channel tensor is reordered to (H, W, C).

    Args:
        img_tensors: iterable of image tensors.
        img_names: iterable of file names, paired with ``img_tensors`` by
            ``zip`` (extra items in the longer one are ignored).
        save_dir: directory the files are written into.
        image_format: PIL format name passed to ``Image.save``. The default
            is the previously hard-coded 'JPEG'.
    """
    for img_tensor, img_name in zip(img_tensors, img_names):
        # detach() lets .numpy() work on tensors that require grad; it replaces
        # the former bare `except:` fallback, which also hid unrelated errors.
        pixels = ((img_tensor.detach() + 1) * 0.5 * 255).cpu().clamp(0, 255)
        array = pixels.numpy().astype('uint8')
        if array.shape[0] == 1:
            array = array.squeeze(0)
        elif array.shape[0] == 3:
            array = array.transpose(1, 2, 0)
        im = Image.fromarray(array)
        im.save(os.path.join(save_dir, img_name), format=image_format)
def create_network(cls, opt):
    """Build ``cls(opt)``, print it, optionally move it to CUDA, then init weights.

    The network object must provide ``print_network()``, ``cuda()`` and
    ``init_weights(init_type, init_variance)``. ``opt`` must provide
    ``gpu_ids``, ``init_type`` and ``init_variance``. The model is moved to
    CUDA only when ``opt.gpu_ids`` is non-empty; CUDA availability is asserted
    first.

    Returns:
        The constructed network.
    """
    network = cls(opt)
    network.print_network()
    wants_gpu = len(opt.gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())
        network.cuda()
    network.init_weights(opt.init_type, opt.init_variance)
    return network