# Trainer.py
from Statistics import Statistics

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_


class Trainer:
    def __init__(self, model):
        self.model = model
        # Read the options from the wrapped module when training with DDP.
        self.opt = (model.module.opt
                    if isinstance(model, nn.parallel.DistributedDataParallel)
                    else model.opt)
        self.start_epoch = self.opt.start_epoch if self.opt.start_epoch else 1
        self.lr = self.opt.learning_rate
        self.betas = (0.9, 0.98)
        self.last_ppl = None  # validation perplexity from the previous epoch
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr,
                                    betas=self.betas, eps=1e-9)
        if 'prev_optim' in self.opt:
            print('Loading previous optimizer state')
            self.optimizer.load_state_dict(self.opt.prev_optim)
            # A restored optimizer state is loaded on the CPU; move its
            # tensors back to the GPU before resuming training.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
    def save_checkpoint(self, epoch, valid_stats):
        # Checkpoint the wrapped module so the state dict keys carry no
        # DistributedDataParallel 'module.' prefix.
        real_model = (self.model.module
                      if isinstance(self.model, nn.parallel.DistributedDataParallel)
                      else self.model)
        model_state_dict = real_model.state_dict()
        # Persist the current (possibly decayed) learning rate in the options.
        self.opt.learning_rate = self.lr
        checkpoint = {
            'model': model_state_dict,
            'vocab': real_model.vocabs,
            'opt': self.opt,
            'epoch': epoch,
            'optim': self.optimizer.state_dict(),
        }
        torch.save(checkpoint,
                   '%s_acc_%.2f_ppl_%.2f_e%d.pt'
                   % (self.opt.save_model + '/model', valid_stats.accuracy(),
                      valid_stats.ppl(), epoch))
    def update_learning_rate(self, valid_stats):
        # Decay the learning rate whenever validation perplexity stops improving.
        if self.last_ppl is not None and valid_stats.ppl() > self.last_ppl:
            self.lr = self.lr * self.opt.learning_rate_decay
            print("Decaying learning rate to %g" % self.lr)
        self.last_ppl = valid_stats.ppl()
        self.optimizer.param_groups[0]['lr'] = self.lr
    def run_train_batched(self, train_data, valid_data, vocabs):
        print(self.model)
        total_train = train_data.compute_batches(
            self.opt.batch_size, vocabs, self.opt.max_camel, 0, 1,
            self.opt.decoder_type, trunc=self.opt.trunc)
        # Grammar-based decoders get a small fixed validation batch size.
        total_valid = valid_data.compute_batches(
            10 if self.opt.decoder_type in ["prod", "concode"] else self.opt.batch_size,
            vocabs, self.opt.max_camel, 0, 1, self.opt.decoder_type,
            randomize=False, trunc=self.opt.trunc)
        print('Computed batches. Total train={}, total valid={}'.format(total_train, total_valid))
        report_stats = Statistics()
        self.last_ppl = None
        for epoch in range(self.start_epoch, self.opt.epochs + 1):
            self.model.train()
            total_stats = Statistics()
            for idx, batch in enumerate(train_data.batches):
                loss, batch_stats = self.model.forward(batch)
                # Normalize the summed loss by the batch size before backprop.
                batch_size = batch['code'].size(0)
                loss.div(batch_size).backward()
                report_stats.update(batch_stats)
                total_stats.update(batch_stats)
                clip_grad_norm_(self.model.parameters(), self.opt.max_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()
                # Since -1 % n == n - 1 in Python, this fires once every
                # self.opt.report_every batches.
                if (idx + 1) % self.opt.report_every == -1 % self.opt.report_every:
                    report_stats.output(epoch, idx + 1, len(train_data.batches),
                                        total_stats.start_time)
                    report_stats = Statistics()
            print('Train perplexity: %g' % total_stats.ppl())
            print('Train accuracy: %g' % total_stats.accuracy())
            # Evaluate without building autograd graphs.
            self.model.eval()
            valid_stats = Statistics()
            with torch.no_grad():
                for idx, batch in enumerate(valid_data.batches):
                    _, batch_stats = self.model.forward(batch)
                    valid_stats.update(batch_stats)
            print('Validation perplexity: %g' % valid_stats.ppl())
            print('Validation accuracy: %g' % valid_stats.accuracy())
            self.update_learning_rate(valid_stats)
            print('Saving model')
            self.save_checkpoint(epoch, valid_stats)
            print('Model saved')
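

# A minimal usage sketch (an assumption, not part of the original file). It
# presumes a model object exposing .opt, .vocabs, and a forward(batch) that
# returns (loss, Statistics), plus dataset objects providing compute_batches()
# and a .batches list, as used above. build_model and load_data are
# hypothetical placeholders for whatever this repo uses to construct them.
#
#     model = build_model(opt, vocabs)           # hypothetical factory
#     train_data, valid_data = load_data(opt)    # hypothetical loader
#     trainer = Trainer(model)
#     trainer.run_train_batched(train_data, valid_data, vocabs)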