lr_find.py
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.auto import tqdm

K = keras.backend


class Scheduler:
    def __init__(self, vals, n_iter: int) -> None:
        'Used to "step" from start to end (`vals`) over `n_iter` iterations on a schedule defined by `func`.'
        self.start, self.end = (
            (vals[0], vals[1]) if isinstance(vals, tuple) else (vals, 0)
        )
        self.n_iter = max(1, n_iter)
        self.func = self._annealing_exp
        self.n = 0

    @staticmethod
    def _annealing_exp(start: float, end: float, pct: float) -> float:
        "Exponentially anneal from `start` to `end` as `pct` goes from 0.0 to 1.0."
        return start * (end / start) ** pct

    def restart(self) -> None:
        self.n = 0

    def step(self) -> float:
        self.n += 1
        return self.func(self.start, self.end, self.n / self.n_iter)

    @property
    def is_done(self) -> bool:
        "Return `True` if the schedule has completed."
        return self.n >= self.n_iter
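
# A quick worked example of the exponential schedule above (illustrative only,
# not part of the original file): with start=1e-7, end=10 and n_iter=100, step n
# returns 1e-7 * (10 / 1e-7) ** (n / 100), i.e. roughly 1e-7, 1e-3 and 10 at
# n = 0, 50 and 100, so the learning rate sweeps eight orders of magnitude
# geometrically rather than linearly.
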
class LrFinder:
    """
    [LrFinder implementation taken from Fast.ai]
    (https://github.com/fastai/fastai/tree/master/fastai)

    The learning rate range test increases the learning rate in a pre-training run
    between two boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of learning rates
    and what the optimal learning rate is.

    Args:
        model (tf.keras.Model): wrapped model
        optimizer (tf.keras.optimizers.Optimizer): wrapped optimizer
        loss_fn (tf.keras.losses.Loss): loss function

    Example:
        >>> lr_finder = LrFinder(model, optimizer, loss_fn)
        >>> lr_finder.range_test(trn_ds, end_lr=100, num_iter=100)
        >>> lr_finder.plot_lrs()  # to inspect the loss-learning rate graph
    """

    def __init__(self,
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 loss_fn: tf.keras.losses.Loss,
                 ) -> None:
        self.lrs = []
        self.losses = []
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.mw = self.model.get_weights()
        self.init_lr = K.get_value(self.optimizer.lr)
        self.iteration = 0
        self.weightsFile = tempfile.mkstemp()[1]

    @tf.function
    def trn_step(self, xb, yb):
        """Performs one training step and returns the loss and gradients."""
        with tf.GradientTape() as tape:
            logits = self.model(xb, training=True)
            main_loss = tf.reduce_mean(self.loss_fn(yb, logits))
            # add any regularization losses registered on the model
            loss = tf.add_n([main_loss] + self.model.losses)
        grads = tape.gradient(loss, self.model.trainable_variables)
        return loss, grads
    def range_test(self,
                   trn_ds: tf.data.Dataset,
                   start_lr: float = 1e-7,
                   end_lr: float = 10,
                   num_iter: int = 100,
                   beta=0.98,
                   ) -> None:
        """
        Explore learning rates from `start_lr` to `end_lr` over `num_iter` iterations of `model`.

        Args:
            trn_ds (tf.data.Dataset): training dataset yielding (x, y) batches.
            start_lr (float, optional): the starting learning rate for the range test.
                Default: 1e-07.
            end_lr (float, optional): the maximum learning rate to test. Default: 10.
            num_iter (int, optional): the number of iterations over which the test
                occurs. Default: 100.
            beta (float, optional): the loss smoothing factor within the [0, 1]
                interval. The loss is smoothed using exponential smoothing.
                Default: 0.98.
        """
        # save original model weights
        try:
            self.model.save_weights(self.weightsFile)
        except Exception:
            print("Unable to save initial weights; the model's weights will change. "
                  "Re-instantiate the model to restore the previous weights ...")
        # start scheduler
        sched = Scheduler((start_lr, end_lr), num_iter)
        avg_loss, best_loss = 0.0, 0.0
        # set the starting lr
        K.set_value(self.optimizer.lr, sched.start)
        print(f"Finding best initial lr over {num_iter} steps")
        # initialize tqdm bar
        bar = tqdm(iterable=range(num_iter))
        # iterate over the batches
        for (xb, yb) in trn_ds:
            self.iteration += 1
            loss, grads = self.trn_step(xb, yb)
            # compute smoothed loss (exponential moving average with bias correction)
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** self.iteration)
            # record best loss
            if self.iteration == 1 or smoothed_loss < best_loss:
                best_loss = smoothed_loss
            # stop if loss is exploding
            if sched.is_done or (
                smoothed_loss > 4 * best_loss or np.isnan(smoothed_loss)
            ):
                break
            # append losses and lrs
            self.losses.append(smoothed_loss)
            self.lrs.append(K.get_value(self.optimizer.lr))
            # update weights
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_variables))
            # update lr
            K.set_value(self.optimizer.lr, sched.step())
            # update tqdm
            bar.update(1)
        # clean-up
        bar.close()
        sched.restart()
        self._print_prompt()
    def _print_prompt(self) -> None:
        "Restore the model weights disturbed during the LrFinder exploration."
        try:
            self.model.load_weights(self.weightsFile)
        except Exception:
            print(
                "Unable to load initial weights. Re-instantiate the model to restore the previous weights ...")
        K.set_value(self.optimizer.lr, self.init_lr)
        print(
            "LR Finder is complete; call {LrFinder}.plot_lrs() to see the graph.")
    @staticmethod
    def _split_list(vals, skip_start: int, skip_end: int) -> list:
        return vals[skip_start:-skip_end] if skip_end > 0 else vals[skip_start:]

    def plot_lrs(self,
                 skip_start: int = 10,
                 skip_end: int = 5,
                 suggestion: bool = False,
                 show_grid: bool = False,
                 ) -> None:
        """
        Plot the learning rates and losses, trimmed between `skip_start` and `skip_end`.
        Optionally mark and store the learning rate at the minimum numerical gradient.
        """
        lrs = self._split_list(self.lrs, skip_start, skip_end)
        losses = self._split_list(self.losses, skip_start, skip_end)
        _, ax = plt.subplots(1, 1)
        ax.plot(lrs, losses)
        ax.set_ylabel("Loss")
        ax.set_xlabel("Learning Rate")
        ax.set_xscale("log")
        if show_grid:
            plt.grid(True, which="both", ls="-")
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter("%.0e"))
        if suggestion:
            try:
                mg = (np.gradient(np.array(losses))).argmin()
            except Exception:
                print(
                    "Failed to compute the gradients; there might not be enough points."
                )
                return
            print(f"Min numerical gradient: {lrs[mg]:.2E}")
            ax.plot(lrs[mg], losses[mg], markersize=10,
                    marker="o", color="red")
            self.min_grad_lr = lrs[mg]
            ml = np.argmin(losses)
            print(f"Min loss divided by 10: {lrs[ml]/10:.2E}")