-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathstats.py
52 lines (43 loc) · 1.59 KB
/
stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from numpy import mean, std
import os, gc
from classifier import Classifier
import warnings
# Hide the UserWarning emitted by torch_sparse. Only the UserWarning category is
# silenced, so other categories (e.g. deprecation warnings from the libraries
# this script uses) are no longer swallowed as they were by a blanket "ignore".
warnings.filterwarnings("ignore", category=UserWarning)

# Command line: a single positional argument naming the directory that holds
# the checkpoints to evaluate.
parser = ArgumentParser()
parser.add_argument(
    "path",
    help="directory containing the model checkpoints (searched recursively)",
)
args = parser.parse_args()
def get_test_valid(model_path, gpus=1):
    """Evaluate one checkpoint on its test and validation dataloaders.

    Loads a ``Classifier`` from ``model_path`` and runs ``Trainer.test`` twice:
    once on ``model.test_dataloader()`` and once on ``model.val_dataloader()``.

    Args:
        model_path: path to a checkpoint loadable by
            ``Classifier.load_from_checkpoint``.
        gpus: forwarded to ``Trainer(gpus=...)``. The default of 1 keeps the
            previous behaviour. NOTE(review): presumably 0 evaluates on CPU;
            confirm against the installed pytorch_lightning version.

    Returns:
        A tuple with the elements of the test results list followed by the
        elements of the validation results list. The caller unpacks exactly
        two values, so each ``trainer.test`` call is expected to return a
        one-element list holding a metric-name -> value dict. TODO confirm.
    """
    model = Classifier.load_from_checkpoint(checkpoint_path=model_path)
    # Evaluation only: no logger and no weights summary printout.
    trainer = Trainer(gpus=gpus, logger=False, weights_summary=None)
    test = trainer.test(model, model.test_dataloader(), verbose=False)
    valid = trainer.test(model, model.val_dataloader(), verbose=False)
    # Drop the references and force a collection so this model's memory can
    # be reclaimed before the caller loads the next checkpoint.
    del model
    del trainer
    gc.collect()
    # Parenthesized so the starred unpacking also parses before Python 3.8.
    return (*test, *valid)
# Per-metric values across all checkpoints: metric name -> list of values,
# with one entry per evaluated checkpoint.
test = {}
valid = {}

# Every file under the given directory, recursively. os.walk yields entries in
# arbitrary, filesystem-dependent order. Sorting makes the evaluation order and
# the "evaluating model i/n" progress output reproducible between runs.
paths = sorted(
    os.path.join(dir_path, file_name)
    for dir_path, _, file_names in os.walk(args.path)
    for file_name in file_names
)

for i, path in enumerate(paths):
    print("evaluating model {}/{}".format(i + 1, len(paths)))
    t, v = get_test_valid(path)
    # Both runs go through trainer.test, so the validation results are
    # expected to carry the same metric names as the test results. A name
    # missing from v raises KeyError here.
    for k in t:
        test.setdefault(k, []).append(t[k])
        valid.setdefault(k, []).append(v[k])


def _summarize(runs):
    """Map each metric to a "mean +/- std" string over its per-checkpoint values.

    The name is cut at the first '/' (the original note says this removes a
    '/valid' suffix), so the test and validation rows share a display name.
    NOTE(review): numpy's std defaults to ddof=0 (population std). Pass ddof=1
    if the sample std across runs is what should be reported.
    """
    return {
        name.split('/')[0]: "{0:.4f} +/- {1:.4f}".format(mean(vals), std(vals))
        for name, vals in runs.items()
    }


test = _summarize(test)
valid = _summarize(valid)

print("\n\nRESULTS")
# Look up each validation row by name rather than zipping the two dicts, so
# rows can never pair mismatched metrics if the two orders ever diverge.
for k, vt in test.items():
    print('{0}: TEST {1} VALID {2}'.format(k, vt, valid[k]))