forked from SKA-INAF/radio-tiramisu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path compute_iou.py
122 lines (96 loc) · 4.12 KB
/
compute_iou.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import datetime
from pathlib import Path
import torch
from torchmetrics import JaccardIndex as IoU
from tqdm import tqdm
import utils.training as train_utils
from datasets.rg_masks import RGDataset
from models import tiramisu
def get_args():
    """Build the command-line parser for this evaluation script.

    Returns:
        argparse.ArgumentParser: parser whose ``parse_args()`` result is the
        ``args`` namespace consumed by ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", default='latest.th',
                        type=str, help="Weights path of the checkpoint to evaluate")
    parser.add_argument(
        "--data_dir", default="data/rg-dataset/data", help="Path of data folder")
    parser.add_argument("--results_dir", default=".results",
                        help="Directory where to store results")
    parser.add_argument("--log_file", default="log.txt",
                        help="Log text file path")
    # type=int: without it, values given on the command line arrive as str and
    # break DataLoader(batch_size=...) and the model constructor.
    parser.add_argument("--batch_size", type=int, default=20,
                        help="Number of images per evaluation batch")
    parser.add_argument("--n_classes", type=int, default=4,
                        help="Number of output classes of the model")
    parser.add_argument("--device", default="cuda",
                        help="Torch device used for the evaluation")
    return parser
def _mean(values):
    """Return the arithmetic mean of ``values`` as a float (NaN when empty)."""
    if not values:
        return float("nan")
    return float(sum(values) / len(values))


def _foreground_accuracy(preds, targets):
    """Return the fraction of foreground target pixels predicted correctly.

    A pixel is foreground when its target label is non-zero, so correctly
    predicted background pixels are never counted as hits. A batch without
    any foreground pixel scores 1.
    """
    foreground = targets != 0
    n_foreground = foreground.sum()
    if n_foreground == 0:
        return 1
    return ((preds == targets) & foreground).sum() / n_foreground


def _binary_masks(preds, targets, class_index):
    """Return one-vs-rest 0/1 masks of ``class_index`` for preds and targets.

    Each mask keeps the dtype of its input tensor.
    """
    return ((preds == class_index).to(preds.dtype),
            (targets == class_index).to(targets.dtype))


def main(args):
    """Evaluate an FCDenseNet67 checkpoint and print accuracy / IoU.

    The images listed in ``data/rg-dataset/val_mask.txt`` are evaluated.
    Metrics are reported for all classes together, then for the extended
    (label 3) and compact (label 2) classes separately:

    * Accuracy: per-batch fraction of foreground target pixels predicted
      correctly, averaged over batches. A batch containing no foreground
      pixel of the relevant class counts as 1 (100%).
    * IoU: torchmetrics ``JaccardIndex`` accumulated over all batches
      (multiclass for "All classes", binary one-vs-rest otherwise).

    Args:
        args: namespace from ``get_args().parse_args()``. Reads ``data_dir``,
            ``results_dir``, ``batch_size``, ``n_classes``, ``device`` and
            ``resume``.
    """
    extended_label = 3
    compact_label = 2

    data_path = Path(args.data_dir)
    # NOTE(review): a timestamped results directory is created, but nothing in
    # this function writes to it.
    results_path = Path(args.results_dir) / \
        datetime.datetime.now().strftime("%Y-%m-%d_%H:%M")
    results_path.mkdir(exist_ok=True, parents=True)
    # int(): tolerate namespaces in which these arrived as CLI strings.
    batch_size = int(args.batch_size)
    n_classes = int(args.n_classes)

    test_dset = RGDataset(data_path, "data/rg-dataset/val_mask.txt")
    test_loader = torch.utils.data.DataLoader(
        test_dset, batch_size=batch_size, shuffle=False, num_workers=4)

    if args.device == 'cuda':
        torch.cuda.manual_seed(0)

    model = tiramisu.FCDenseNet67(n_classes=n_classes).to(args.device)
    train_utils.load_weights(model, args.resume)
    model.eval()

    # num_classes must match the model's output classes; otherwise argmax can
    # produce labels the metric was not configured for.
    iou_all = IoU(task="multiclass", num_classes=n_classes).to(args.device)
    iou_ext = IoU(task="binary").to(args.device)
    iou_comp = IoU(task="binary").to(args.device)
    accs_all, accs_ext, accs_comp = [], [], []

    for data, target in tqdm(test_loader, desc="Testing"):
        data = data.to(args.device)
        targets = target.to(args.device)
        with torch.no_grad():
            # Model output already lives on args.device, so no extra .to().
            preds = model(data).argmax(1)

        # All classes
        iou_all.update(preds, targets)
        accs_all.append(_foreground_accuracy(preds, targets))

        # Extended class only
        preds_ext, targets_ext = _binary_masks(preds, targets, extended_label)
        iou_ext.update(preds_ext, targets_ext)
        accs_ext.append(_foreground_accuracy(preds_ext, targets_ext))

        # Compact class only
        preds_comp, targets_comp = _binary_masks(preds, targets, compact_label)
        iou_comp.update(preds_comp, targets_comp)
        accs_comp.append(_foreground_accuracy(preds_comp, targets_comp))

    for title, accs, iou in (("All classes", accs_all, iou_all),
                             ("\nOnly Extended", accs_ext, iou_ext),
                             ("\nOnly Compact", accs_comp, iou_comp)):
        print(title)
        print(f'Accuracy: {_mean(accs) * 100:.2f}')
        print(f'IoU: {iou.compute() * 100:.2f}')
if __name__ == '__main__':
    parser = get_args()
    # Evaluate the augmented compact checkpoint by default, but let an explicit
    # --resume on the command line take precedence instead of being
    # overwritten after parsing.
    parser.set_defaults(resume="weights/augmentation/augmented-compact.pth")
    main(parser.parse_args())