-
Notifications
You must be signed in to change notification settings - Fork 22
/
make_results.py
78 lines (63 loc) · 1.96 KB
/
make_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import time
import os
import math
import argparse
from glob import glob
from collections import OrderedDict
import random
import warnings
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import joblib
from sklearn.model_selection import StratifiedKFold, train_test_split
from skimage.io import imread
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import datasets, models, transforms
from lib.dataset import Dataset
from lib.models.model_factory import get_model
from lib.utils import *
from lib.metrics import *
from lib.losses import *
from lib.preprocess import preprocess
def parse_args(argv=None):
    """Parse command-line options for the per-fold results summariser.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly like a plain ``parse_args()``.

    Returns:
        argparse.Namespace with a ``name`` attribute: the model directory
        name under ``models/``.
    """
    parser = argparse.ArgumentParser()
    # There is no usable default: main() loads models/<name>/args.pkl, so a
    # missing --name should fail fast with a usage error rather than looking
    # for a "models/None/" directory.
    parser.add_argument('--name', default=None, required=True,
                        help='model name (directory under models/)')
    args = parser.parse_args(argv)
    return args
def main():
    """Summarise per-fold validation results for a trained model.

    Loads the training arguments pickled in ``models/<name>/args.pkl``. For
    every fold ``1..args.n_splits`` whose ``log_<fold>.csv`` exists, it picks
    the row with the lowest non-NaN ``val_loss`` and records that row's
    ``val_loss`` and ``val_score``. The per-fold rows plus a ``mean`` row are
    printed and written to ``models/<name>/results.csv``.

    Folds with no log file, or whose log has no non-NaN ``val_loss``, are
    skipped. If no fold contributes a result, nothing is written.
    """
    test_args = parse_args()
    # Read logs and write results in the same directory args.pkl came from,
    # so every path in this function refers to one model directory.
    model_dir = 'models/%s' % test_args.name
    # NOTE(review): joblib.load unpickles the file; only use on trusted dirs.
    args = joblib.load('%s/args.pkl' % model_dir)

    folds = []
    losses = []
    scores = []
    for fold in range(1, args.n_splits + 1):
        log_path = '%s/log_%d.csv' % (model_dir, fold)
        if not os.path.exists(log_path):
            continue
        log = pd.read_csv(log_path)
        val_loss = log['val_loss']
        if not val_loss.notna().any():
            # Empty log or every epoch NaN: there is no best epoch to report.
            print('fold %d: no valid val_loss in %s, skipping' % (fold, log_path))
            continue
        # idxmin skips NaN and returns an index *label*, which is what .loc
        # expects. ndarray.argmin would return a position, and it would pick
        # the first NaN epoch if one existed.
        best = val_loss.idxmin()
        loss, score = log.loc[best, ['val_loss', 'val_score']].values
        print(loss, score)
        folds.append(str(fold))
        losses.append(loss)
        scores.append(score)

    if not folds:
        # Avoid np.mean([]) warnings and a NaN-only results file.
        print('no fold results found in %s' % model_dir)
        return

    results = pd.DataFrame({
        'fold': folds + ['mean'],
        'loss': losses + [np.mean(losses)],
        'score': scores + [np.mean(scores)],
    })
    print(results)
    results.to_csv('%s/results.csv' % model_dir, index=False)
if __name__ == "__main__":
    # Summarise results only when run as a script, not on import.
    main()