-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
93 lines (82 loc) · 3.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import random
import torch
import numpy as np
import yaml
# Key fragments used to address sub-modules inside each Transformer block of
# the ViT-B_16 pretrained (JAX/Flax-style) state dict.
# Fix: ATTENTION_OUT was previously assigned twice with the same value;
# the duplicate assignment has been removed.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
# Attribute name of the classification head in SimpleViT.
HEAD = "linear_head"
def np2th(weights, conv=False):
    """Convert a NumPy weight array into a torch tensor.

    :param weights: numpy array from a pretrained checkpoint
    :param conv: if True, permute a conv kernel from HWIO to OIHW layout
    :return: torch.Tensor sharing memory with the (possibly permuted) array
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)
def seed_everything(seed):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN kernels, for reproducible runs.

    :param seed: integer seed applied to all generators
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed,
                   np.random.seed,
                   torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
def load_config(config_dir, config_id):
    """Load and parse the YAML config file `<config_id>.YAML` from `config_dir`.

    :param config_dir: directory containing the config files
    :param config_id: basename (without extension) of the config to load
    :return: the parsed config object (typically a dict)
    """
    # NOTE(review): extension is uppercase ".YAML" — confirm the files on disk
    # really use that casing; this matters on case-sensitive filesystems.
    # NOTE(review): yaml.load with FullLoader is fine for trusted local configs
    # only; prefer yaml.safe_load if configs may come from untrusted sources.
    with open(os.path.join(config_dir, f"{config_id}.YAML")) as file:
        config = yaml.load(file, Loader=yaml.FullLoader)
    return config
def get_cv_folds(cv, test_kfold):
    """Split cross-validation folds into train and test samples.

    :param cv: mapping of fold id (string key) -> list of samples
    :param test_kfold: id of the fold held out for testing (coerced to str)
    :return: (train_fold, test_fold) where train_fold is the flattened
             concatenation of every fold except the test one
    """
    key = str(test_kfold)
    test_fold = cv[key]
    train_fold = []
    for fold_id, samples in cv.items():
        if fold_id != key:
            train_fold.extend(samples)
    return train_fold, test_fold
class EarlyStopper:
    """Early-stopping criterion based on a running median of a validation metric.

    Stopping rule: the running median over the last `agg` validation steps is
    worse than ANY previously recorded running median by more than `delta`.
    `loss_check_stop` treats higher as worse; `acc_check_stop` treats lower
    as worse.  The duplicated median bookkeeping of the two check methods is
    factored into `_current_median`.
    """

    def __init__(self, agg=200, delta=0.005):
        """
        :param agg: number of validation steps to aggregate the metric over
        :param delta: maximum tolerated change of the running median before stopping
        """
        self.history = []        # every metric value passed to step()
        self.medians = []        # running medians recorded at each check
        self.agg = agg
        self.delta = delta
        self.current_step = 0    # counts validations performed, not global steps

    def step(self, v):
        """Record one validation metric value."""
        self.history.append(v)
        self.current_step += 1

    def _current_median(self):
        """Compute the running median over the last `agg` values and record it.

        Appending before the comparison in the callers is intentional: the
        current median is compared against itself too, which is always a
        no-op for non-negative delta.
        """
        current = np.median(self.history[self.current_step - self.agg:self.current_step])
        self.medians.append(current)
        return current

    def loss_check_stop(self):
        """For a loss-like metric (lower is better): return True when the
        current running median is HIGHER than any recorded median by > delta."""
        if self.current_step < self.agg:
            # not enough observations yet to form a full window
            return False
        current = self._current_median()
        return any(current > m + self.delta for m in self.medians)

    def acc_check_stop(self):
        """For an accuracy-like metric (higher is better): return True when the
        current running median is LOWER than any recorded median by > delta."""
        if self.current_step < self.agg:
            # not enough observations yet to form a full window
            return False
        current = self._current_median()
        return any(current < m - self.delta for m in self.medians)