-
Notifications
You must be signed in to change notification settings - Fork 2
/
auxiliary_functions.py
62 lines (47 loc) · 1.51 KB
/
auxiliary_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 24 14:03:24 2023
@author: AmayaGS
"""
import os
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
class Accuracy_Logger:
    """Track per-class sample counts and correct predictions.

    For each class index ``c`` in ``range(n_classes)`` the logger keeps
    ``count`` (number of logged samples whose true label is ``c``) and
    ``correct`` (how many of those had a prediction equal to ``c``).
    """

    def __init__(self, n_classes):
        """Create a logger with all per-class tallies zeroed.

        Args:
            n_classes: Number of classes. True labels are used directly as
                list indices into the per-class tallies.
        """
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset every per-class tally to zero."""
        self.data = [{"count": 0, "correct": 0} for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record a single prediction.

        Args:
            Y_hat: Predicted class (any value accepted by ``int()``).
            Y: True class (any value accepted by ``int()``).
        """
        Y_hat = int(Y_hat)
        Y = int(Y)
        self.data[Y]["count"] += 1
        self.data[Y]["correct"] += int(Y_hat == Y)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions.

        Args:
            Y_hat: Predicted classes, array-like accepted by ``np.array``.
            Y: True classes, array-like accepted by ``np.array``; must be
                shape-compatible with ``Y_hat`` for boolean-mask indexing.
        """
        Y_hat = np.array(Y_hat).astype(int)
        Y = np.array(Y).astype(int)
        for label_class in np.unique(Y):
            cls_mask = Y == label_class
            entry = self.data[int(label_class)]
            # Cast NumPy scalars to Python ints so the tallies have the same
            # type whether they were accumulated via log() or log_batch().
            entry["count"] += int(cls_mask.sum())
            entry["correct"] += int((Y_hat[cls_mask] == Y[cls_mask]).sum())

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        ``accuracy`` is ``correct / count`` as a float, or ``None`` when no
        sample of class ``c`` has been logged yet.
        """
        count = self.data[c]["count"]
        correct = self.data[c]["correct"]
        if count == 0:
            acc = None
        else:
            acc = float(correct) / count
        return acc, correct, count
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every random number generator.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects child processes started
        afterwards; hash randomization of the running interpreter is fixed
        when it starts.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored if CUDA is unavailable
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick convolution algorithms per
    # run, which can differ between runs and defeats deterministic=True.
    torch.backends.cudnn.benchmark = False