tf_bfgs.py
(forked from titu1994/Neural-Style-Transfer)
from abc import ABC, abstractmethod

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp


# Ported from https://pychao.com/2019/11/02/optimize-tensorflow-keras-models-with-l-bfgs-from-tensorflow-probability/
class AbstractTFPOptimizer(ABC):

    def __init__(self, trace_function=False):
        super(AbstractTFPOptimizer, self).__init__()
        self.trace_function = trace_function
        self.callback_list = None

    def _function_wrapper(self, loss_func, model):
        """A factory to create a function required by tfp.optimizer.lbfgs_minimize.

        Args:
            loss_func: a function with signature loss_value = loss(model).
            model: an instance of `tf.keras.Model` or its subclasses.

        Returns:
            A function with the signature:
                loss_value, gradients = f(model_parameters).
        """
        # obtain the shapes of all trainable parameters in the model
        shapes = tf.shape_n(model.trainable_variables)
        n_tensors = len(shapes)

        # we'll use tf.dynamic_stitch and tf.dynamic_partition later, so we need to
        # prepare the required information first
        count = 0
        idx = []  # stitch indices
        part = []  # partition indices

        for i, shape in enumerate(shapes):
            n = np.prod(shape)
            idx.append(tf.reshape(tf.range(count, count + n, dtype=tf.int32), shape))
            part.extend([i] * n)
            count += n

        part = tf.constant(part)

        @tf.function
        def assign_new_model_parameters(params_1d):
            """Updates the model's parameters from a 1D tf.Tensor.

            Args:
                params_1d [in]: a 1D tf.Tensor representing the model's trainable parameters.
            """
            params = tf.dynamic_partition(params_1d, part, n_tensors)
            for i, (shape, param) in enumerate(zip(shapes, params)):
                model.trainable_variables[i].assign(tf.reshape(param, shape))

        # now create a function that will be returned by this factory
        def f(params_1d):
            """A function that can be used by tfp.optimizer.lbfgs_minimize.

            This function is created by `_function_wrapper`.

            Args:
                params_1d [in]: a 1D tf.Tensor.

            Returns:
                A scalar loss and the gradients w.r.t. `params_1d`.
            """
            # use GradientTape so that we can calculate the gradient of the loss w.r.t. the parameters
            with tf.GradientTape() as tape:
                # update the parameters in the model
                assign_new_model_parameters(params_1d)
                # calculate the loss
                loss_value = loss_func(model)

            # calculate the gradients and convert them to a 1D tf.Tensor
            grads = tape.gradient(loss_value, model.trainable_variables)
            grads = tf.dynamic_stitch(idx, grads)

            # print out iteration & loss
            f.iter.assign_add(1)
            tf.print("Iter:", f.iter, "loss:", loss_value)

            if self.callback_list is not None:
                info_dict = {
                    'iter': f.iter,
                    'loss': loss_value,
                    'grad': grads,
                }

                for callback in self.callback_list:
                    callback(model, info_dict=info_dict)

            return loss_value, grads

        if self.trace_function:
            f = tf.function(f)

        # store this information as attributes of `f` so it can be used outside this scope
        f.iter = tf.Variable(0, trainable=False)
        f.idx = idx
        f.part = part
        f.shapes = shapes
        f.assign_new_model_parameters = assign_new_model_parameters

        return f

    def register_callback(self, callable):
        """
        Registers a callable with the signature `callback(model, info_dict=None)`.

        The callable should not return anything; any return value is ignored.

        `info_dict` will contain the following information:
            - Optimizer iteration number (key = 'iter')
            - Loss value (key = 'loss')
            - Gradient value (key = 'grad')

        See `example_logging_callback` below for a minimal sketch.

        Args:
            callable: A callable function with the signature `callable(model, info_dict=None)`.
                See above for what `info_dict` can contain.
        """
        if self.callback_list is None:
            self.callback_list = []

        self.callback_list.append(callable)

    @abstractmethod
    def minimize(self, loss_func, model):
        pass
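
# Example: a minimal callback sketch matching the `callback(model, info_dict=None)`
# signature expected by `register_callback`. The name `example_logging_callback` and
# the logging behaviour are illustrative only (not part of the original module);
# adapt it for checkpointing, early stopping, etc.
def example_logging_callback(model, info_dict=None):
    """Logs the optimizer iteration, loss, and gradient norm for every evaluation."""
    if info_dict is not None:
        tf.print("[callback] iter:", info_dict['iter'],
                 "loss:", info_dict['loss'],
                 "grad norm:", tf.norm(info_dict['grad']))
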
class BFGSOptimizer(AbstractTFPOptimizer):

    def __init__(self, max_iterations=50, tolerance=1e-8, bfgs_kwargs=None, trace_function=False):
        super(BFGSOptimizer, self).__init__(trace_function=trace_function)
        self.max_iterations = max_iterations
        self.tolerance = tolerance

        bfgs_kwargs = bfgs_kwargs or {}

        if 'max_iterations' in bfgs_kwargs.keys():
            del bfgs_kwargs['max_iterations']

        # the tolerances are set explicitly in `minimize`, so drop any tolerance-related
        # keys from the kwargs to avoid passing duplicate keyword arguments
        keys = [key for key in bfgs_kwargs.keys() if 'tolerance' in key]
        for key in keys:
            del bfgs_kwargs[key]

        self.bfgs_kwargs = bfgs_kwargs

    def minimize(self, loss_func, model):
        optim_func = self._function_wrapper(loss_func, model)

        # convert the initial model parameters to a 1D tf.Tensor
        init_params = tf.dynamic_stitch(optim_func.idx, model.trainable_variables)

        # train the model with the BFGS solver
        results = tfp.optimizer.bfgs_minimize(
            value_and_gradients_function=optim_func, initial_position=init_params,
            max_iterations=self.max_iterations,
            tolerance=self.tolerance,
            x_tolerance=self.tolerance,
            f_relative_tolerance=self.tolerance,
            **self.bfgs_kwargs)

        # after training, the final optimized parameters are still in results.position,
        # so we have to manually assign them back to the model
        optim_func.assign_new_model_parameters(results.position)

        print("BFGS complete, and parameters updated!")
        return model
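
# Note: tfp.optimizer.bfgs_minimize maintains a dense inverse-Hessian estimate, so its
# memory use grows quadratically with the number of parameters. For models with many
# parameters, the L-BFGS variant below is usually the more practical choice.
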
class LBFGSOptimizer(AbstractTFPOptimizer):

    def __init__(self, max_iterations=50, tolerance=1e-8, lbfgs_kwargs=None, trace_function=False):
        super(LBFGSOptimizer, self).__init__(trace_function=trace_function)
        self.max_iterations = max_iterations
        self.tolerance = tolerance

        lbfgs_kwargs = lbfgs_kwargs or {}

        if 'max_iterations' in lbfgs_kwargs.keys():
            del lbfgs_kwargs['max_iterations']

        # the tolerances are set explicitly in `minimize`, so drop any tolerance-related
        # keys from the kwargs to avoid passing duplicate keyword arguments
        keys = [key for key in lbfgs_kwargs.keys() if 'tolerance' in key]
        for key in keys:
            del lbfgs_kwargs[key]

        self.lbfgs_kwargs = lbfgs_kwargs

    def minimize(self, loss_func, model):
        optim_func = self._function_wrapper(loss_func, model)

        # convert the initial model parameters to a 1D tf.Tensor
        init_params = tf.dynamic_stitch(optim_func.idx, model.trainable_variables)

        # train the model with the L-BFGS solver
        results = tfp.optimizer.lbfgs_minimize(
            value_and_gradients_function=optim_func, initial_position=init_params,
            max_iterations=self.max_iterations,
            tolerance=self.tolerance,
            x_tolerance=self.tolerance,
            f_relative_tolerance=self.tolerance,
            **self.lbfgs_kwargs)

        # after training, the final optimized parameters are still in results.position,
        # so we have to manually assign them back to the model
        optim_func.assign_new_model_parameters(results.position)

        print("L-BFGS complete, and parameters updated!")
        return model
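
# Usage sketch (illustrative, not part of the original module). It assumes a small
# synthetic regression problem; the data `x`, `y`, the model, and `mse_loss` are
# placeholders. Any `tf.keras.Model` and any loss of the form
# `loss_value = loss_func(model)` work the same way, and BFGSOptimizer is a drop-in
# replacement for LBFGSOptimizer here.
if __name__ == '__main__':
    x = tf.random.normal((64, 4))
    y = tf.reduce_sum(x, axis=1, keepdims=True)

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1),
    ])

    def mse_loss(m):
        # loss_func must have the signature `loss_value = loss_func(model)`
        return tf.reduce_mean(tf.square(m(x) - y))

    optimizer = LBFGSOptimizer(max_iterations=100, tolerance=1e-8)
    optimizer.register_callback(example_logging_callback)
    model = optimizer.minimize(mse_loss, model)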