trainer.py
import torch
import numpy as np
from tqdm import tqdm
from os.path import dirname, abspath, join, exists
from os import makedirs
from datetime import datetime
import json

# Index of the padding token in the vocabulary.
PAD_INDEX = 0
# Checkpoints and configs are saved relative to this file's location.
BASE_DIR = dirname(abspath(__file__))


class EpochSeq2SeqTrainer:
    """Epoch-based trainer for sequence-to-sequence models: runs training and
    validation passes, logs progress, and checkpoints models by epoch."""

    def __init__(self, model,
                 train_dataloader, val_dataloader,
                 loss_function, metric_function, optimizer,
                 logger, run_name,
                 save_config, save_checkpoint,
                 config):

        self.config = config
        self.device = torch.device(self.config['device'])

        self.model = model.to(self.device)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.loss_function = loss_function.to(self.device)
        self.metric_function = metric_function
        self.optimizer = optimizer
        self.clip_grads = self.config['clip_grads']

        self.logger = logger
        self.checkpoint_dir = join(BASE_DIR, 'checkpoints', run_name)
        if not exists(self.checkpoint_dir):
            makedirs(self.checkpoint_dir)

        # Persist the run configuration alongside the checkpoints (or at an
        # explicitly provided path) so the run can be reproduced later.
        if save_config is None:
            config_filepath = join(self.checkpoint_dir, 'config.json')
        else:
            config_filepath = save_config
        with open(config_filepath, 'w') as config_file:
            json.dump(self.config, config_file)

        self.print_every = self.config['print_every']
        self.save_every = self.config['save_every']

        self.epoch = 0
        self.history = []
        self.start_time = datetime.now()

        self.best_val_metric = None
        self.best_checkpoint_filepath = None

        self.save_checkpoint = save_checkpoint
        self.save_format = 'epoch={epoch:0>3}-val_loss={val_loss:<.3}-val_metrics={val_metrics}.pth'

        self.log_format = (
            "Epoch: {epoch:>3} "
            "Progress: {progress:<.1%} "
            "Elapsed: {elapsed} "
            "Examples/second: {per_second:<.1} "
            "Train Loss: {train_loss:<.6} "
            "Val Loss: {val_loss:<.6} "
            "Train Metrics: {train_metrics} "
            "Val Metrics: {val_metrics} "
            "Learning rate: {current_lr:<.4} "
        )

    def run_epoch(self, dataloader, mode='train'):
        """Runs one full pass over the dataloader; weights are updated only in
        'train' mode. Returns the per-token loss and [perplexity, accuracy]."""
        batch_losses = []
        batch_counts = []
        batch_metrics = []
        for sources, inputs, targets in tqdm(dataloader):
            sources, inputs, targets = sources.to(self.device), inputs.to(self.device), targets.to(self.device)
            outputs = self.model(sources, inputs)

            # The loss function returns the summed loss and the number of
            # tokens it covered, so losses can later be averaged per token.
            batch_loss, batch_count = self.loss_function(outputs, targets)

            if mode == 'train':
                self.optimizer.zero_grad()
                batch_loss.backward()
                if self.clip_grads:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                self.optimizer.step()

            batch_losses.append(batch_loss.item())
            batch_counts.append(batch_count)

            batch_metric, batch_metric_count = self.metric_function(outputs, targets)
            batch_metrics.append(batch_metric)

            # Loss and metric must be computed over the same number of tokens.
            assert batch_count == batch_metric_count

        if self.epoch == 0:  # for testing: epoch 0 reports placeholder statistics
            return float('inf'), [float('inf')]

        epoch_loss = sum(batch_losses) / sum(batch_counts)
        epoch_accuracy = sum(batch_metrics) / sum(batch_counts)
        epoch_perplexity = float(np.exp(epoch_loss))
        epoch_metrics = [epoch_perplexity, epoch_accuracy]

        return epoch_loss, epoch_metrics

    def run(self, epochs=10):
        for epoch in range(self.epoch, epochs + 1):
            self.epoch = epoch
            self.model.train()

            epoch_start_time = datetime.now()
            train_epoch_loss, train_epoch_metrics = self.run_epoch(self.train_dataloader, mode='train')
            epoch_end_time = datetime.now()

            self.model.eval()
            val_epoch_loss, val_epoch_metrics = self.run_epoch(self.val_dataloader, mode='val')

            if epoch % self.print_every == 0 and self.logger:
                # The +1 guards against division by zero on sub-second epochs.
                per_second = len(self.train_dataloader.dataset) / ((epoch_end_time - epoch_start_time).seconds + 1)
                current_lr = self.optimizer.param_groups[0]['lr']
                log_message = self.log_format.format(
                    epoch=epoch,
                    progress=epoch / epochs,
                    per_second=per_second,
                    train_loss=train_epoch_loss,
                    val_loss=val_epoch_loss,
                    train_metrics=[round(metric, 4) for metric in train_epoch_metrics],
                    val_metrics=[round(metric, 4) for metric in val_epoch_metrics],
                    current_lr=current_lr,
                    elapsed=self._elapsed_time()
                )

                self.logger.info(log_message)

            if epoch % self.save_every == 0:
                self._save_model(epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics)

    def _save_model(self, epoch, train_epoch_loss, val_epoch_loss, train_epoch_metrics, val_epoch_metrics):

        checkpoint_filename = self.save_format.format(
            epoch=epoch,
            val_loss=val_epoch_loss,
            val_metrics='-'.join(['{:<.3}'.format(v) for v in val_epoch_metrics])
        )

        if self.save_checkpoint is None:
            checkpoint_filepath = join(self.checkpoint_dir, checkpoint_filename)
        else:
            checkpoint_filepath = self.save_checkpoint

        save_state = {
            'epoch': epoch,
            'train_loss': train_epoch_loss,
            'train_metrics': train_epoch_metrics,
            'val_loss': val_epoch_loss,
            'val_metrics': val_epoch_metrics,
            'checkpoint': checkpoint_filepath,
        }

        # Epoch 0 reports placeholder (inf) statistics (see run_epoch), so its
        # checkpoint is neither saved nor tracked.
        if self.epoch > 0:
            torch.save(self.model.state_dict(), checkpoint_filepath)
            self.history.append(save_state)

            # The first element of val_epoch_metrics is perplexity; lower is better.
            representative_val_metric = val_epoch_metrics[0]
            if self.best_val_metric is None or self.best_val_metric > representative_val_metric:
                self.best_val_metric = representative_val_metric
                self.val_loss_at_best = val_epoch_loss
                self.train_loss_at_best = train_epoch_loss
                self.train_metrics_at_best = train_epoch_metrics
                self.val_metrics_at_best = val_epoch_metrics
                self.best_checkpoint_filepath = checkpoint_filepath

            if self.logger:
                self.logger.info("Saved model to {}".format(checkpoint_filepath))
                self.logger.info("Current best model is {}".format(self.best_checkpoint_filepath))

    def _elapsed_time(self):
        now = datetime.now()
        elapsed = now - self.start_time
        return str(elapsed).split('.')[0]  # strip the microseconds component
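

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original project). Everything
# below — ToyModel, TokenCrossEntropy, token_accuracy, the random data — is a
# hypothetical stand-in illustrating the interfaces the trainer assumes: the
# model takes (sources, inputs), the loss function returns (summed_loss,
# token_count), and the metric function returns (metric_sum, token_count).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import logging
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    VOCAB_SIZE, SEQ_LEN = 32, 8

    class ToyModel(nn.Module):
        """Stand-in seq2seq model: embeds decoder inputs and projects to logits."""
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(VOCAB_SIZE, 16)
            self.projection = nn.Linear(16, VOCAB_SIZE)

        def forward(self, sources, inputs):
            return self.projection(self.embedding(inputs))

    class TokenCrossEntropy(nn.Module):
        """Returns (summed loss, number of non-pad tokens), as the trainer expects."""
        def forward(self, outputs, targets):
            loss = nn.functional.cross_entropy(
                outputs.reshape(-1, VOCAB_SIZE), targets.reshape(-1),
                ignore_index=PAD_INDEX, reduction='sum')
            count = (targets != PAD_INDEX).sum().item()
            return loss, count

    def token_accuracy(outputs, targets):
        """Returns (number of correct non-pad tokens, number of non-pad tokens)."""
        predictions = outputs.argmax(dim=-1)
        mask = targets != PAD_INDEX
        correct = ((predictions == targets) & mask).sum().item()
        return correct, mask.sum().item()

    def random_loader():
        # Random non-pad token ids standing in for a real tokenized dataset.
        sources = torch.randint(1, VOCAB_SIZE, (64, SEQ_LEN))
        inputs = torch.randint(1, VOCAB_SIZE, (64, SEQ_LEN))
        targets = torch.randint(1, VOCAB_SIZE, (64, SEQ_LEN))
        return DataLoader(TensorDataset(sources, inputs, targets), batch_size=16)

    logging.basicConfig(level=logging.INFO)
    model = ToyModel()
    trainer = EpochSeq2SeqTrainer(
        model=model,
        train_dataloader=random_loader(),
        val_dataloader=random_loader(),
        loss_function=TokenCrossEntropy(),
        metric_function=token_accuracy,
        optimizer=torch.optim.Adam(model.parameters()),
        logger=logging.getLogger('trainer'),
        run_name='smoke_test',
        save_config=None, save_checkpoint=None,
        config={'device': 'cpu', 'clip_grads': True,
                'print_every': 1, 'save_every': 1},
    )
    # Epoch 0 is the trainer's placeholder pass; epochs 1-2 train and
    # checkpoint under checkpoints/smoke_test/. Weights could be restored
    # later with model.load_state_dict(torch.load(checkpoint_path)).
    trainer.run(epochs=2)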