train_eval.py
import math
import torch.optim as optim
import torch
import numpy as np
# Ignore divide-by-zero / invalid warnings from the correlation computation below.
np.seterr(divide='ignore', invalid='ignore')


def evaluate(data, X, Y, model, evaluateL2, evaluateL1, args):
    """Run the model over the evaluation set and return RSE, RAE and correlation."""
    model.eval()
    total_loss = 0
    total_loss_l1 = 0
    n_samples = 0
    predict = None
    test = None

    for X, Y in data.get_batches(X, Y, args.batch_size, False):
        output = model(X)

        # Accumulate predictions and targets for the correlation statistic.
        if predict is None:
            predict = output.clone().detach()
            test = Y
        else:
            predict = torch.cat((predict, output.clone().detach()))
            test = torch.cat((test, Y))

        # Losses are computed on the rescaled (original-magnitude) values.
        scale = data.scale.expand(output.size(0), data.m)
        total_loss += float(evaluateL2(output * scale, Y * scale).data.item())
        total_loss_l1 += float(evaluateL1(output * scale, Y * scale).data.item())
        n_samples += int(output.size(0) * data.m)

    rse = math.sqrt(total_loss / n_samples) / data.rse
    rae = (total_loss_l1 / n_samples) / data.rae

    predict = predict.data.cpu().numpy()
    Ytest = test.data.cpu().numpy()
    sigma_p = predict.std(axis=0)
    sigma_g = Ytest.std(axis=0)
    mean_p = predict.mean(axis=0)
    mean_g = Ytest.mean(axis=0)
    index = (sigma_g != 0)
    # Per-variable correlation, averaged over the variables with non-zero variance.
    correlation = ((predict - mean_p) * (Ytest - mean_g)).mean(axis=0) / (sigma_p * sigma_g)
    correlation = (correlation[index]).mean()
    return rse, rae, correlation


def train(data, X, Y, model, criterion, optim, args):
    """One pass over the training data; returns the average per-element loss."""
    model.train()
    total_loss = 0
    n_samples = 0

    for X, Y in data.get_batches(X, Y, args.batch_size, True):
        optim.zero_grad()
        output = model(X)
        scale = data.scale.expand(output.size(0), data.m)
        loss = criterion(output * scale, Y * scale)
        loss.backward()
        # Clip gradients to stabilise training before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optim.step()
        total_loss += loss.data.item()
        n_samples += int(output.size(0) * data.m)

    return total_loss / n_samples


def makeOptimizer(params, args):
    """Build the optimizer selected by args.optim."""
    if args.optim == 'sgd':
        optimizer = optim.SGD(params, lr=args.lr)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(params, lr=args.lr)
    elif args.optim == 'adadelta':
        optimizer = optim.Adadelta(params, lr=args.lr)
    elif args.optim == 'adam':
        optimizer = optim.Adam(params, lr=args.lr)
    else:
        raise RuntimeError("Invalid optim method: " + args.optim)
    return optimizer
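

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, self-contained
# example of how train(), evaluate() and makeOptimizer() fit together.
# The _ToyData stub, the linear model and the argument values below are
# illustrative assumptions only; the real project supplies its own data
# utility (exposing get_batches, scale, m, rse, rae) and its own model.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    class _ToyData:
        """Minimal stand-in exposing the attributes the functions above rely on."""

        def __init__(self, n=64, window=8, m=4):
            self.m = m
            self.scale = torch.ones(m)   # per-variable rescaling factors
            self.rse = 1.0               # normalisers for RSE / RAE
            self.rae = 1.0
            self.X = torch.randn(n, window, m)
            self.Y = torch.randn(n, m)

        def get_batches(self, X, Y, batch_size, shuffle):
            idx = torch.randperm(len(X)) if shuffle else torch.arange(len(X))
            for start in range(0, len(X), batch_size):
                sel = idx[start:start + batch_size]
                yield X[sel], Y[sel]

    args = SimpleNamespace(batch_size=16, lr=1e-3, clip=10.0, optim='adam')
    data = _ToyData()

    # Toy model: flatten the input window and regress all m variables at once.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8 * 4, 4))
    criterion = torch.nn.MSELoss(reduction='sum')
    evaluateL2 = torch.nn.MSELoss(reduction='sum')
    evaluateL1 = torch.nn.L1Loss(reduction='sum')
    optimizer = makeOptimizer(model.parameters(), args)

    for epoch in range(3):
        train_loss = train(data, data.X, data.Y, model, criterion, optimizer, args)
        rse, rae, corr = evaluate(data, data.X, data.Y, model, evaluateL2, evaluateL1, args)
        print(f'epoch {epoch}: train {train_loss:.4f} | rse {rse:.4f} rae {rae:.4f} corr {corr:.4f}')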