evaluate.py
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
import argparse
from tqdm import tqdm

from utils.config import *
# from model.model_proxy import Proxy, ModelWithLoss
from model.model_proxy_SAM import SonarSAM, ModelWithLoss
from model.loss_functions import compute_dice_accuracy, compute_multilabel_dice_accuracy, compute_multilabel_IoU
from dataloader.data_loader import DebrisDataset, collate_fn_seq
from utils.utils import rand_seed
from model.segment_anything.utils.transforms import ResizeLongestSide

# class names for the debris dataset; index 0 is the background class
label_list = ["Background", "Bottle", "Can", "Chain", "Drink-carton", "Hook",
              "Propeller", "Shampoo-bottle", "Standing-bottle", "Tire", "Valve",
              "Wall"]
def evaluate(net, val_loader, device, opt):
    # one list of per-batch scores for each class in label_list
    dice_ = [[], [], [], [], [], [], [], [], [], [], [], []]

    net.eval()
    with torch.no_grad():
        for val_step, (images, masks, boxes) in enumerate(tqdm(val_loader)):
            images = images.to(device)
            masks = masks.to(device)

            # skip samples without any annotated pixels
            if torch.sum(masks) == 0:
                continue

            boxes = None
            predictions = net.forward(images, boxes)

            # crop the left/right borders so only the central part of the width is scored
            start_x = int(opt.INPUT_SIZE / 3.0) // 2
            end_x = opt.INPUT_SIZE - 1 - start_x
            masks = masks[:, :, :, start_x:end_x].contiguous()
            predictions = predictions[:, :, :, start_x:end_x].contiguous()

            # eval metric
            if opt.EVAL_METRIC == 'DICE':
                dice_iter = compute_multilabel_dice_accuracy(masks, predictions)
            elif opt.EVAL_METRIC == 'IoU':
                dice_iter = compute_multilabel_IoU(masks, predictions)
            else:
                raise ValueError("Unsupported EVAL_METRIC: {}".format(opt.EVAL_METRIC))

            dice_iter = dice_iter.squeeze()
            for i in range(len(dice_iter)):
                # skip classes that do not appear in this batch
                if torch.sum(masks[:, i, ...]) > 0:
                    dice_[i].append(dice_iter[i].cpu().item())

    # average the per-class scores and store them in a dict
    avg_list = []
    metrics_dict = {}
    for i in range(len(label_list)):
        if len(dice_[i]) == 0:
            d = torch.tensor(0)
        else:
            d = torch.mean(torch.tensor(dice_[i]))
        metrics_dict[label_list[i]] = d
        avg_list.append(d)
    metrics_dict['avg'] = torch.mean(torch.tensor(avg_list))
    metrics_dict['avg(exclude_bg)'] = torch.mean(torch.tensor(avg_list[1:]))
    return metrics_dict
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="config path (*.yaml)", required=True)
    parser.add_argument("--save_path", type=str, help="save path", required=True)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # opt = Config(config_path=args.config)
    opt = Config_SAM(config_path=args.config)

    # dataset
    test_dataset = DebrisDataset(root_path=opt.DATA_PATH, image_list=os.path.join(opt.IMAGE_LIST_PATH, 'test.txt'),
                                 input_size=opt.INPUT_SIZE, use_augment=False)
    test_loader = DataLoader(test_dataset, batch_size=opt.VAL_BATCHSIZE, shuffle=False, collate_fn=collate_fn_seq)

    rand_seed(opt.RANDOM_SEED)

    # model
    net = SonarSAM(model_name=opt.SAM_NAME, checkpoint=opt.SAM_CHECKPOINT, num_classes=opt.OUTPUT_CHN,
                   is_finetune_image_encoder=opt.IS_FINETUNE_IMAGE_ENCODER,
                   use_adaptation=opt.USE_ADAPTATION, adaptation_type=opt.ADAPTATION_TYPE,
                   head_type=opt.HEAD_TYPE,
                   reduction=4, upsample_times=2, groups=4)
    net = ModelWithLoss(net)

    # load the best checkpoint saved under save_path
    ckpt = torch.load(os.path.join(args.save_path, '{}_best.pth'.format(opt.MODEL_NAME)))
    net.load_state_dict(ckpt['state_dict'])
    net.to(device)

    metrics_dict = evaluate(net.model, test_loader, device, opt)

    print("{} on Test set:".format(opt.EVAL_METRIC))
    for key in metrics_dict.keys():
        print("{}:\t{:.2f}".format(key, metrics_dict[key] * 100))