-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathtest.py
318 lines (285 loc) · 14.1 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import os
import cv2
import json
import torch
import random
import logging
import argparse
import numpy as np
from PIL import Image
from skimage import measure
from tabulate import tabulate
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
import open_clip
from few_shot import memory
from model import LinearLayer
from dataset import VisaDataset, MVTecDataset
from prompt_ensemble import encode_text_with_prompt_ensemble
def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with ``seed``. It then forces deterministic cuDNN
    algorithms and turns off cuDNN benchmark autotuning.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def normalize(pred, max_value=None, min_value=None):
    """Min-max rescale ``pred`` so ``min_value`` maps to 0 and ``max_value`` to 1.

    Args:
        pred: NumPy array or tensor supporting ``min``/``max`` and arithmetic.
        max_value: upper bound of the range. If either bound is None, both
            bounds are taken from ``pred`` itself.
        min_value: lower bound of the range (see ``max_value``).

    Returns:
        The rescaled values, in the same container type as ``pred``. When the
        range is zero (e.g. a constant anomaly map), zeros are returned instead
        of the NaN/inf that the division would otherwise produce.
    """
    if max_value is None or min_value is None:
        max_value, min_value = pred.max(), pred.min()
    value_range = max_value - min_value
    if value_range == 0:
        # Multiply rather than build np.zeros_like so the container type and
        # float dtype match what the regular division path would return.
        return (pred - min_value) * 0.0
    return (pred - min_value) / value_range
def apply_ad_scoremap(image, scoremap, alpha=0.5):
    """Blend a JET-colored anomaly score map over an image.

    Args:
        image: image accepted by ``np.asarray``, with the same height/width as
            ``scoremap``. Assumed RGB, since the colormap is converted to RGB
            before blending.
        scoremap: 2-D scores expected in [0, 1]. NaN is treated as 0 and values
            outside [0, 1] are clipped.
        alpha: weight of ``image`` in the blend; the colormap gets ``1 - alpha``.

    Returns:
        uint8 array ``alpha * image + (1 - alpha) * colored_scoremap``.
    """
    np_image = np.asarray(image, dtype=float)
    # Sanitize before the uint8 cast: an out-of-range score (e.g. 1.01 * 255)
    # may wrap around to a "cold" color, and NaN has no defined uint8 value.
    scores = np.clip(np.nan_to_num(scoremap, nan=0.0), 0.0, 1.0)
    colored = cv2.applyColorMap((scores * 255).astype(np.uint8), cv2.COLORMAP_JET)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    return (alpha * np_image + (1 - alpha) * colored).astype(np.uint8)
def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    """Compute the normalized area under the per-region-overlap (PRO) curve.

    Adapted from https://github.com/gudovskiy/cflow-ad/blob/master/train.py.

    ``amaps`` is thresholded at evenly spaced values (step
    ``(max - min) / max_step``) from its global minimum up to its maximum.
    For each threshold, two quantities are recorded:

    - the mean overlap of every connected ground-truth region with the
      binarized prediction (PRO);
    - the false-positive rate over normal pixels.

    Only points with FPR below ``expect_fpr`` are kept. Their FPRs are
    min-max rescaled to [0, 1] before integrating.

    Args:
        masks: stack of 2-D binary ground-truth masks, shape (N, H, W).
            Nonzero pixels count as anomalous.
        amaps: anomaly score maps with the same shape as ``masks``.
        max_step: number of threshold steps between the min and max score.
        expect_fpr: FPR integration limit.

    Returns:
        The PRO AUC. Degenerate cases return a value instead of crashing:

        - ``0.0`` for constant anomaly maps, or when no threshold reaches an
          FPR below ``expect_fpr``;
        - the best PRO reached when all kept points share one FPR (a flat
          curve);
        - ``nan`` when the masks contain no anomalous region.
    """
    min_th, max_th = amaps.min(), amaps.max()
    delta = (max_th - min_th) / max_step
    if delta <= 0:
        # Constant anomaly maps: np.arange would raise on a zero step and no
        # curve can be traced.
        return 0.0

    # Ground-truth regions and the normal-pixel mask do not depend on the
    # threshold, so compute them once instead of once per step.
    regions = [
        (idx, region.coords[:, 0], region.coords[:, 1], region.area)
        for idx, mask in enumerate(masks)
        for region in measure.regionprops(measure.label(mask))
    ]
    if not regions:
        # No anomalous pixels at all: the per-region overlap is undefined.
        return float('nan')
    inverse_masks = np.logical_not(masks)
    num_normal = inverse_masks.sum()

    pros, fprs = [], []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps = amaps > th
        overlaps = [binary_amaps[idx][rows, cols].sum() / area for idx, rows, cols, area in regions]
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        pros.append(np.mean(overlaps))
        fprs.append(fp_pixels / num_normal)

    pros, fprs = np.array(pros), np.array(fprs)
    keep = fprs < expect_fpr
    pros, fprs = pros[keep], fprs[keep]
    if fprs.size == 0:
        # No operating point within the FPR budget.
        return 0.0
    if fprs.max() == fprs.min():
        # All kept points share one FPR, so min-max rescaling would divide by
        # zero (and sklearn's auc needs a non-degenerate x axis). Treat the
        # curve as flat at the best overlap reached.
        return float(pros.max())
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    return auc(fprs, pros)
def _build_logger(txt_path):
    """Return the 'test' logger, writing INFO records to ``txt_path`` and the console.

    Root handlers are removed and the root level is raised to WARNING,
    presumably to silence INFO output from other libraries. Any handlers left
    on the 'test' logger by an earlier call are closed and removed first, so
    repeated calls neither duplicate log lines nor leak open log files.
    """
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)

    logger = logging.getLogger('test')
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    file_handler = logging.FileHandler(txt_path, mode='w')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger


def _max_f1(gt, scores):
    """Return the best F1 over all thresholds of the precision-recall curve.

    Points where precision + recall == 0 yield a non-finite F1 and are skipped.
    """
    precisions, recalls, _ = precision_recall_curve(gt, scores)
    with np.errstate(divide='ignore', invalid='ignore'):
        f1_scores = (2 * precisions * recalls) / (precisions + recalls)
    return np.max(f1_scores[np.isfinite(f1_scores)])


def _fmt(value):
    """Format a [0, 1] metric as a percentage string rounded to one decimal."""
    return str(np.round(value * 100, decimals=1))


def test(args):
    """Evaluate the trained linear layer on the test split and log a metrics table.

    Steps:

    1. Load the open_clip backbone and the ``trainable_linearlayer`` weights
       from ``args.checkpoint_path``.
    2. Score every test image at image level (text-prompt softmax) and pixel
       level (per-layer patch/text similarity). In few-shot mode, add the
       minimum cosine distance to the few-shot memory features.
    3. Write a heat-map overlay per image under
       ``<save_path>/imgs/<class>/<parent dir>/``.
    4. Log per-object and mean metrics to the console and
       ``<save_path>/log.txt``: pixel AUROC / max-F1 / AP / AUPRO and image
       AUROC / max-F1 / AP.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
    """
    img_size = args.image_size
    features_list = args.features_list
    few_shot_features = args.few_shot_features
    dataset_dir = args.data_path
    save_path = args.save_path
    dataset_name = args.dataset
    os.makedirs(save_path, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    txt_path = os.path.join(save_path, 'log.txt')

    # clip
    model, _, preprocess = open_clip.create_model_and_transforms(args.model, img_size, pretrained=args.pretrained)
    model.to(device)
    tokenizer = open_clip.get_tokenizer(args.model)

    # logger
    logger = _build_logger(txt_path)

    # record parameters (few-shot-only options are omitted in zero-shot mode)
    for arg in vars(args):
        if args.mode == 'zero_shot' and arg in ('k_shot', 'few_shot_features'):
            continue
        logger.info(f'{arg}: {getattr(args, arg)}')

    # seg
    with open(args.config_path, 'r') as f:
        model_configs = json.load(f)
    linearlayer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],
                              len(features_list), args.model).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host
    # (``device`` falls back to "cpu" above).
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    linearlayer.load_state_dict(checkpoint["trainable_linearlayer"])

    # dataset
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor()
    ])
    if dataset_name == 'mvtec':
        test_data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform,
                                 aug_rate=-1, mode='test')
    else:
        test_data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform, mode='test')
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    obj_list = test_data.get_cls_names()

    # few shot
    if args.mode == 'few_shot':
        mem_features = memory(args.model, model, obj_list, dataset_dir, save_path, preprocess, transform,
                              args.k_shot, few_shot_features, dataset_name, device)

    # text prompt
    with torch.cuda.amp.autocast(), torch.no_grad():
        text_prompts = encode_text_with_prompt_ensemble(model, obj_list, tokenizer, device)

    results = {'cls_names': [], 'imgs_masks': [], 'anomaly_maps': [], 'gt_sp': [], 'pr_sp': []}
    for items in test_dataloader:
        image = items['img'].to(device)
        cls_name = items['cls_name']
        results['cls_names'].append(cls_name[0])
        gt_mask = items['img_mask']
        # Binarize the ground-truth mask in place at 0.5.
        gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
        results['imgs_masks'].append(gt_mask)  # px
        results['gt_sp'].append(items['anomaly'].item())

        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features, patch_tokens = model.encode_image(image, features_list)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features = torch.stack([text_prompts[name] for name in cls_name], dim=0)

            # sample: image score is the softmax probability at index 1
            # (presumably the "abnormal" prompt set; confirm against
            # encode_text_with_prompt_ensemble)
            text_probs = (100.0 * image_features @ text_features[0]).softmax(dim=-1)
            results['pr_sp'].append(text_probs[0][1].cpu().item())

            # pixel: per-layer patch/text similarity, upsampled and summed over layers
            patch_tokens = linearlayer(patch_tokens)
            anomaly_maps = []
            for layer in range(len(patch_tokens)):
                patch_tokens[layer] /= patch_tokens[layer].norm(dim=-1, keepdim=True)
                anomaly_map = (100.0 * patch_tokens[layer] @ text_features)
                B, L, _ = anomaly_map.shape
                H = int(np.sqrt(L))
                anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                            size=img_size, mode='bilinear', align_corners=True)
                anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :]
                anomaly_maps.append(anomaly_map.cpu().numpy())
            anomaly_map = np.sum(anomaly_maps, axis=0)

            # few shot: add, per layer, the minimum cosine distance of each
            # patch to the class's memory features
            if args.mode == 'few_shot':
                image_features, patch_tokens = model.encode_image(image, few_shot_features)
                anomaly_maps_few_shot = []
                for idx, p in enumerate(patch_tokens):
                    if 'ViT' in args.model:
                        p = p[0, 1:, :]  # drop the first token (presumably CLS)
                    else:
                        p = p[0].view(p.shape[1], -1).permute(1, 0).contiguous()
                    cos = pairwise.cosine_similarity(mem_features[cls_name[0]][idx].cpu(), p.cpu())
                    height = int(np.sqrt(cos.shape[1]))
                    anomaly_map_few_shot = np.min((1 - cos), 0).reshape(1, 1, height, height)
                    anomaly_map_few_shot = F.interpolate(torch.tensor(anomaly_map_few_shot),
                                                         size=img_size, mode='bilinear', align_corners=True)
                    anomaly_maps_few_shot.append(anomaly_map_few_shot[0].cpu().numpy())
                anomaly_map_few_shot = np.sum(anomaly_maps_few_shot, axis=0)
                anomaly_map = anomaly_map + anomaly_map_few_shot

            results['anomaly_maps'].append(anomaly_map)

            # visualization; os.path keeps path parsing portable across separators
            img_path = items['img_path'][0]
            parent_dir = os.path.basename(os.path.dirname(img_path))
            filename = os.path.basename(img_path)
            vis = cv2.cvtColor(cv2.resize(cv2.imread(img_path), (img_size, img_size)), cv2.COLOR_BGR2RGB)  # RGB
            mask = normalize(anomaly_map[0])
            vis = apply_ad_scoremap(vis, mask)
            vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)  # BGR
            save_vis = os.path.join(save_path, 'imgs', cls_name[0], parent_dir)
            os.makedirs(save_vis, exist_ok=True)
            cv2.imwrite(os.path.join(save_vis, filename), vis)

    # metrics (column order of the logged table)
    metric_names = ['auroc_px', 'f1_px', 'ap_px', 'aupro', 'auroc_sp', 'f1_sp', 'ap_sp']
    per_metric = {name: [] for name in metric_names}
    table_ls = []
    for obj in obj_list:
        idxs = [i for i, name in enumerate(results['cls_names']) if name == obj]
        gt_px = np.array([results['imgs_masks'][i].squeeze(1).numpy() for i in idxs])
        pr_px = np.array([results['anomaly_maps'][i] for i in idxs])
        gt_sp = np.array([results['gt_sp'][i] for i in idxs])
        pr_sp = np.array([results['pr_sp'][i] for i in idxs])
        if args.mode == 'few_shot':
            # Average the text-based score with the min-max normalized peak of the pixel map.
            pr_sp_tmp = normalize(np.array([np.max(results['anomaly_maps'][i]) for i in idxs]))
            pr_sp = 0.5 * (pr_sp + pr_sp_tmp)

        gt_px_flat, pr_px_flat = gt_px.ravel(), pr_px.ravel()
        # cal_pro_score needs (N, H, W); drop the singleton channel axis if present.
        gt_px_3d = gt_px.squeeze(1) if gt_px.ndim == 4 else gt_px
        pr_px_3d = pr_px.squeeze(1) if pr_px.ndim == 4 else pr_px
        metrics = {
            'auroc_px': roc_auc_score(gt_px_flat, pr_px_flat),
            'f1_px': _max_f1(gt_px_flat, pr_px_flat),
            'ap_px': average_precision_score(gt_px_flat, pr_px_flat),
            'aupro': cal_pro_score(gt_px_3d, pr_px_3d),
            'auroc_sp': roc_auc_score(gt_sp, pr_sp),
            'f1_sp': _max_f1(gt_sp, pr_sp),
            'ap_sp': average_precision_score(gt_sp, pr_sp),
        }
        table_ls.append([obj] + [_fmt(metrics[name]) for name in metric_names])
        for name in metric_names:
            per_metric[name].append(metrics[name])

    # logger
    table_ls.append(['mean'] + [_fmt(np.mean(per_metric[name])) for name in metric_names])
    table_str = tabulate(table_ls, headers=['objects'] + metric_names, tablefmt="pipe")
    logger.info("\n%s", table_str)
if __name__ == '__main__':
    parser = argparse.ArgumentParser("VAND Challenge", add_help=True)
    # paths
    parser.add_argument("--data_path", type=str, default="./data/visa", help="path to test dataset")
    parser.add_argument("--save_path", type=str, default='./results/tiaoshi', help='path to save results')
    # NOTE(review): this default points at what looks like a ViT-huge experiment,
    # while --model/--config_path default to ViT-B-16; pass matching values explicitly.
    parser.add_argument("--checkpoint_path", type=str, default='./exps/vit_huge_14/model_epoch12.pth',
                        help='path to the trained linear-layer checkpoint')
    parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-B-16.json', help="model configs")
    # model
    # NOTE(review): --data_path defaults to a VisA directory but --dataset defaults to 'mvtec';
    # test() loads any value other than 'mvtec' as VisA.
    parser.add_argument("--dataset", type=str, default='mvtec', help="test dataset")
    parser.add_argument("--model", type=str, default="ViT-B-16", help="model used")
    parser.add_argument("--pretrained", type=str, default="laion400m_e32", help="pretrained weight used")
    parser.add_argument("--features_list", type=int, nargs="+", default=[3, 6, 9], help="features used")
    parser.add_argument("--few_shot_features", type=int, nargs="+", default=[3, 6, 9], help="features used for few shot")
    parser.add_argument("--image_size", type=int, default=224, help="image size")
    # test() only checks for these two values, so a typo would otherwise silently run zero-shot.
    parser.add_argument("--mode", type=str, default="zero_shot", choices=["zero_shot", "few_shot"],
                        help="zero shot or few shot")
    # few shot
    parser.add_argument("--k_shot", type=int, default=10, help="e.g., 10-shot, 5-shot, 1-shot")
    parser.add_argument("--seed", type=int, default=10, help="random seed")
    args = parser.parse_args()
    setup_seed(args.seed)
    test(args)