# utils.py (forked from moskomule/senet.pytorch)
import os

import torch
from torch.optim import Optimizer
from tqdm import tqdm


class Trainer(object):
    """Minimal train/eval loop wrapper around a model, an optimizer and a loss function."""

    cuda = torch.cuda.is_available()

    def __init__(self, model, optimizer, loss_f, save_dir=None, save_freq=5):
        self.model = model
        if self.cuda:
            model.cuda()
        self.optimizer = optimizer
        self.loss_f = loss_f
        self.save_dir = save_dir
        self.save_freq = save_freq

    def _loop(self, data_loader, is_train=True):
        loop_loss = []
        correct = []
        for data, target in tqdm(data_loader):
            if self.cuda:
                data, target = data.cuda(), target.cuda()
            if is_train:
                self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_f(output, target)
            # Accumulate loss averaged over batches and accuracy averaged over samples.
            loop_loss.append(loss.item() / len(data_loader))
            correct.append((output.max(1)[1] == target).sum().item() / len(data_loader.dataset))
            if is_train:
                loss.backward()
                self.optimizer.step()
        mode = "train" if is_train else "test"
        print(f">>>[{mode}] loss: {sum(loop_loss):.2f}/accuracy: {sum(correct):.2%}")
        return loop_loss, correct

    def train(self, data_loader):
        self.model.train()
        loss, correct = self._loop(data_loader)

    def test(self, data_loader):
        self.model.eval()
        with torch.no_grad():
            loss, correct = self._loop(data_loader, is_train=False)

    def loop(self, epochs, train_data, test_data, scheduler=None):
        for ep in range(1, epochs + 1):
            if scheduler is not None:
                scheduler.step()
            print(f"epochs: {ep}")
            self.train(train_data)
            self.test(test_data)
            # Checkpoint every `save_freq` epochs.
            if ep % self.save_freq == 0:
                self.save(ep)

    def save(self, epoch, **kwargs):
        if self.save_dir:
            name = f"weight-{epoch}-" + "-".join([f"{k}_{v}" for k, v in kwargs.items()]) + ".pkl"
            torch.save({"weight": self.model.state_dict()},
                       os.path.join(self.save_dir, name))
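

# Usage sketch (not part of the original module): `net`, `train_loader` and
# `test_loader` are hypothetical placeholders for any nn.Module and two
# DataLoaders; the hyperparameters are illustrative only.
#
#   optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
#   trainer = Trainer(net, optimizer, torch.nn.functional.cross_entropy,
#                     save_dir="./checkpoints", save_freq=5)
#   trainer.loop(epochs=90, train_data=train_loader, test_data=test_loader)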

# copied from pytorch's master
class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
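

# Minimal sketch (not from the original file) of the get_lr() contract: a
# subclass only has to map self.last_epoch and self.base_lrs to new rates.
# The class name and the 0.95 decay factor are illustrative assumptions.
class _ExponentialDecayLR(_LRScheduler):
    def __init__(self, optimizer, gamma=0.95, last_epoch=-1):
        self.gamma = gamma
        super(_ExponentialDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Decay every base lr by gamma once per epoch.
        return [base_lr * self.gamma ** self.last_epoch for base_lr in self.base_lrs]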

class StepLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    decayed by gamma every step_size epochs. When last_epoch=-1, sets
    initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05    if epoch < 30
        >>> # lr = 0.005   if 30 <= epoch < 60
        >>> # lr = 0.0005  if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]
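

# Self-contained smoke test (not part of the original file): a toy linear
# classifier on random data, wired through Trainer and StepLR as defined
# above. Shapes and hyperparameters are arbitrary illustrative choices.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    inputs = torch.randn(64, 10)          # 64 samples, 10 features
    targets = torch.randint(0, 3, (64,))  # 3 classes
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=16)

    model = nn.Linear(10, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = StepLR(optimizer, step_size=2, gamma=0.1)

    trainer = Trainer(model, optimizer, F.cross_entropy)
    trainer.loop(epochs=3, train_data=loader, test_data=loader, scheduler=scheduler)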