-
Notifications
You must be signed in to change notification settings - Fork 9
/
sgdr.py
86 lines (71 loc) · 3.38 KB
/
sgdr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from tensorflow.keras.callbacks import Callback
import tensorflow.keras.backend as K
import numpy as np
class SGDRScheduler(Callback):
    '''Cosine annealing learning rate scheduler with periodic restarts.

    Within a cycle the learning rate follows half a cosine wave from
    `max_lr` down to `min_lr`. When a cycle ends, the batch counter is
    reset, the cycle length is scaled by `mult_factor`, `max_lr` is scaled
    by `lr_decay`, and the model weights are snapshotted so they can be
    restored when training finishes.

    # Usage
        ```python
            schedule = SGDRScheduler(min_lr=1e-5,
                                     max_lr=1e-2,
                                     steps_per_epoch=np.ceil(epoch_size/batch_size),
                                     lr_decay=0.9,
                                     cycle_length=5,
                                     mult_factor=1.5)
            model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
        ```

    # Arguments
        min_lr: The lower bound of the learning rate range for the experiment.
        max_lr: The upper bound of the learning rate range for the experiment.
        steps_per_epoch: Number of mini-batches in the dataset.
            Calculated as `np.ceil(epoch_size/batch_size)`.
        lr_decay: Reduce the max_lr after the completion of each cycle.
            Ex. To reduce the max_lr by 20% after each cycle, set this
            value to 0.8.
        cycle_length: Initial number of epochs in a cycle.
        mult_factor: Scale the cycle length (epochs until the next restart)
            after each full cycle completion.

    # Attributes
        history: Dict mapping `'lr'` and every key found in the per-batch
            `logs` to the list of values recorded after each batch.
        best_weights: Weights captured at the end of the most recently
            completed cycle, or `None` if no cycle has completed yet.

    # References
        Blog post: jeremyjordan.me/nn-learning-rate
        Original paper: http://arxiv.org/abs/1608.03983
    '''

    def __init__(self,
                 min_lr,
                 max_lr,
                 steps_per_epoch,
                 lr_decay=1,
                 cycle_length=10,
                 mult_factor=2):
        # Let the base Callback set up its own attributes first.
        super().__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_decay = lr_decay

        # Number of batches processed since the last restart.
        self.batch_since_restart = 0
        # 1-based epoch number at which the current cycle ends.
        self.next_restart = cycle_length

        self.steps_per_epoch = steps_per_epoch

        self.cycle_length = cycle_length
        self.mult_factor = mult_factor

        self.history = {}
        # Defined up front so on_train_end can tell "no snapshot yet" apart
        # from a real snapshot instead of raising AttributeError.
        self.best_weights = None

    def clr(self):
        '''Calculate the learning rate for the current position in the cycle.

        # Returns
            `max_lr` when no batches have run since the last restart,
            annealing along a cosine curve to `min_lr` once
            `steps_per_epoch * cycle_length` batches have run.
        '''
        fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
        lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
        return lr

    def on_train_begin(self, logs=None):
        '''Initialize the learning rate to the maximum value at the start of training.'''
        K.set_value(self.model.optimizer.lr, self.max_lr)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}

        # Log the rate that was in effect for the batch that just finished.
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        '''Check for end of current cycle, apply restarts when necessary.'''
        # `epoch` is 0-based while `next_restart` counts completed epochs.
        if epoch + 1 == self.next_restart:
            self.batch_since_restart = 0
            self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
            self.next_restart += self.cycle_length
            self.max_lr *= self.lr_decay
            # Snapshot taken at the low-learning-rate end of the cycle.
            self.best_weights = self.model.get_weights()

    def on_train_end(self, logs=None):
        '''Set weights to the values from the end of the most recent cycle for best performance.

        If training stopped before the first cycle completed there is no
        snapshot, and the model's current weights are left untouched.
        '''
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)