eval.py
import argparse
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import tqdm

from data import get_dataset
parser = argparse.ArgumentParser(description='Predicting with high correlation features')
# Directories
parser.add_argument('--data', type=str, default='datasets/',
                    help='location of the data corpus')
parser.add_argument('--root_dir', type=str, default='default/',
                    help='root dir path to save the log and the final model')
parser.add_argument('--save_dir', type=str, default='0/',
                    help='dir path (inside root_dir) to save the log and the final model')
# Dataset
parser.add_argument('--dataset', type=str, default='mnistm',
                    help='dataset name')
# Adaptive BN
parser.add_argument('--bn_eval', action='store_true',
                    help='adapt BN stats during eval')
# Hyperparameters
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--bs', type=int, default=128, metavar='N',
                    help='batch size')
# Meta specifications
parser.add_argument('--cuda', action='store_false',
                    help='use CUDA')
parser.add_argument('--gpu', nargs='+', type=int, default=[0])
args = parser.parse_args()
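
# Example invocation (directory and dataset names are illustrative; point them
# at wherever your training run saved its checkpoint):
#   python eval.py --root_dir my_run/ --save_dir 0/ --dataset mnistm --bn_eval
# Note: --cuda is declared with action='store_false', so passing --cuda turns the
# flag off; in this script GPU use is actually governed by torch.cuda.is_available().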
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in args.gpu)
args.root_dir = os.path.join('runs/', args.root_dir)
args.save_dir = os.path.join(args.root_dir, args.save_dir)
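# Fix the random seed (CPU and, when available, CUDA) so evaluation is reproducible.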
use_cuda = torch.cuda.is_available()
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)
###############################################################################
# Load data
###############################################################################
print('==> Preparing data..')
trainloader, validloader, testloader, nb_classes, dim_inp = get_dataset(args)
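
# test() evaluates `model` on all batches from `loader` and returns top-1
# accuracy in percent. When --bn_eval is set, the data is first forwarded
# through the model twice in train() mode so that BatchNorm running statistics
# adapt to the evaluation distribution before accuracy is measured.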
def test(loader, model):
    global best_acc, args
    if args.bn_eval:  # forward prop data twice to update BN running averages
        model.train()
        for _ in range(2):
            for batch_idx, (inputs, targets) in enumerate(loader):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                _ = model(inputs, train=False)
    model.eval()
    correct, total = 0, 0
    tot_iters = len(loader)
    # Iterate the loader directly; re-creating the iterator each step with
    # next(iter(loader)) would repeatedly fetch from a fresh iterator instead
    # of sweeping the whole evaluation set once.
    for inputs, targets in tqdm.tqdm(loader, total=tot_iters):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        with torch.no_grad():
            inputs, targets = Variable(inputs), Variable(targets)
            outputs = model(inputs, train=False)
            _, predicted = torch.max(nn.Softmax(dim=1)(outputs).data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
    acc = 100. * float(correct) / float(total)
    return acc
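
# The checkpoint written by the (separate) training script is assumed to be a
# dict holding the full serialized model under the 'model' key, which is what
# best_state['model'] below relies on.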
with open(args.save_dir + '/best_model.pt', 'rb') as f:
    best_state = torch.load(f)
model = best_state['model']
if use_cuda:
    model.cuda()
# Run on the test and validation sets.
test_acc = test(testloader, model=model)
best_val_acc = test(validloader, model=model)
print('=' * 89)
status = 'Test acc {:3.4f} at best val acc {:3.4f}'.format(test_acc, best_val_acc)
print(status)