-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy path_components.py
396 lines (321 loc) · 13.8 KB
/
_components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn import Parameter
from torch.distributions import Multinomial
from torchvision.utils import make_grid
class Visible:
    """TensorBoard visualization mixin for layers.

    Meant to be mixed into an ``nn.Module`` subclass (``self.training`` is
    read). Logging uses ``self.tb_logger`` (assigned externally) and labels
    values with ``self.name`` and ``self.iter``. Teacher/student comparison is
    driven by the optional attributes ``is_teacher`` and ``students``, which
    are set outside this class; when absent the layer is treated as a
    non-teacher.
    """

    def __init__(self, vis_channels=None):
        self.vis = False  # one-shot flag, cleared by forward_vis after logging
        self.name = None
        self.iter = -1
        self.tb_logger = None
        # optional index into the channel axis selecting channels to render as images
        self.vis_channels = vis_channels

    def vis_feat(self, feat):
        """Log a histogram of ``feat``; for Conv2d hosts also log an image grid.

        The grid shows every (or every selected) channel of the first sample,
        each as a single-channel tile.
        """
        with torch.no_grad():
            prefix = self.name if self.training else f"{self.name}/eval"
            self.tb_logger.add_histogram(prefix + "/feat", feat, self.iter)
            if isinstance(self, nn.Conv2d):
                # first sample (C, H, W) -> (C, 1, H, W): one tile per channel
                y0 = feat[0].unsqueeze(0).permute(1, 0, 2, 3)
                if self.vis_channels is not None:
                    y0 = y0[self.vis_channels]
                n_rows = math.ceil(math.sqrt(y0.size(0)))
                y0_img = make_grid(y0, nrow=n_rows, normalize=True)
                self.tb_logger.add_image(prefix + "/feat", y0_img, self.iter)

    def vis_error(self, feat, ref_feat):
        """Log the error ``feat - ref_feat`` as a histogram and its SNR in dB."""
        with torch.no_grad():
            err = feat - ref_feat
            # Norms are amplitudes, so the dB ratio uses 20*log10
            # (equivalently 10*log10 of the power ratio).
            snr = 20. * torch.log10(ref_feat.norm() / err.norm())
            prefix = self.name if self.training else f"{self.name}/eval"
            self.tb_logger.add_histogram(prefix + "/err", err, self.iter)
            self.tb_logger.add_scalar("SNR/" + prefix, snr.item(), self.iter)

    def vis_weight(self):
        """Log a weight histogram while ``vis`` is set and ``iter <= 1``."""
        if self.vis and self.iter <= 1:
            w = self.weight.detach()
            self.tb_logger.add_histogram(f"{self.name}/weight", w.view(-1), self.iter)

    def vis_bound(self):
        """Log the scalar ``lb``/``ub`` attributes of the host when ``vis`` is set."""
        if self.vis:
            self.tb_logger.add_scalar(f"{self.name}/lb", self.lb.item(), self.iter)
            self.tb_logger.add_scalar(f"{self.name}/ub", self.ub.item(), self.iter)

    def forward_vis(self, y):
        """Run one-shot visualization for the layer output ``y``.

        Does nothing unless ``self.vis`` is set. A teacher layer hands a
        snapshot of ``y`` to each of its ``students`` as ``ref_feat``. Any
        other layer compares ``y`` with a pending ``ref_feat`` (if any) and
        then clears it. ``self.vis`` is reset at the end.
        """
        # TODO: rewrite as decorator
        if not self.vis:
            return
        self.vis_feat(y)
        if getattr(self, "is_teacher", False):
            # Clone so later in-place ops on y (e.g. ReLU(inplace=True))
            # cannot change the students' reference feature.
            ref = y.detach().clone()
            for m in getattr(self, "students", ()):
                m.ref_feat = ref
        else:
            ref_feat = getattr(self, "ref_feat", None)
            if ref_feat is not None:
                assert ref_feat.shape == y.shape, \
                    f"FP and quant feat on layer `{self.name}` shape not match"
                self.vis_error(y, ref_feat)
            self.ref_feat = None
        self.vis = False
###############################################################################
# Denoiser
###############################################################################
class NonLocal(nn.Module):
    """Non-local denoising block with a residual 1x1 projection.

    Each spatial position is replaced by an affinity-weighted sum of the
    features at all positions. The result goes through a 1x1 convolution, is
    added back to the input, and a ReLU is applied. Adapted from
    https://github.com/facebookresearch/ImageNet-Adversarial-Training.

    Args:
        channels: number of input and output channels.
        inplace: forwarded to ``F.relu``.
        softmax: if True, affinities are normalized with a scaled softmax
            (dot-product attention). Otherwise the weighted sum is divided
            by ``h * w``.
    """

    def __init__(self, channels, inplace=False, softmax=False):
        super().__init__()
        self.channels = channels
        self.inplace = inplace
        self.softmax = softmax
        self.mapping = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        n, c, h, w = x.shape
        res = x
        theta, phi, g = x, x, x
        if c > h * w or self.softmax:
            # explicit (h*w) x (h*w) affinity between every pair of positions
            f = torch.einsum('niab,nicd->nabcd', theta, phi)
            if self.softmax:
                orig_shape = f.shape
                f = f.reshape(-1, h * w, h * w)
                f = f / math.sqrt(c)
                # normalize over the attended positions (c, d); torch.softmax
                # has no default dim, and the TF reference normalizes the last axis
                f = torch.softmax(f, dim=-1)
                f = f.reshape(orig_shape)
            f = torch.einsum('nabcd,nicd->niab', f, g)
        else:
            # Same product associated the other way (c x c Gram matrix first):
            # cheaper when c <= h*w, and only valid without softmax.
            f = torch.einsum('nihw,njhw->nij', phi, g)
            f = torch.einsum('nij,nihw->njhw', f, theta)
        if not self.softmax:
            f = f / (h * w)
        f = f.reshape(x.shape)
        y = self.mapping(f) + res
        return F.relu(y, self.inplace)
###############################################################################
# Naive quant with STE
###############################################################################
class NaiveQuantSTE(Function):
    """Uniform min/max quantization with a straight-through gradient.

    Forward snaps ``x`` onto ``2 ** k`` evenly spaced levels spanning
    ``[x.min(), x.max()]``. Backward passes the incoming gradient through
    unchanged, and ``k`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, k):
        n = 2 ** k - 1
        lb = x.min()
        delta = (x.max() - lb) / n
        # A constant tensor has zero range. Use a unit step so it maps onto
        # itself instead of producing 0 / 0 = NaN.
        delta = torch.where(delta > 0, delta, torch.ones_like(delta))
        return x.sub(lb).div(delta).round().mul(delta).add(lb)

    @staticmethod
    def backward(ctx, dy):
        return dy.clone(), None
class NaiveQuantConv2d(nn.Conv2d, Visible):
    """``nn.Conv2d`` with optional min/max weight quantization.

    With ``quant`` enabled, the weight is passed through
    ``NaiveQuantSTE`` at ``bit_width`` bits before the convolution. Every
    forward ends with ``forward_vis`` on the output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 quant=False, bit_width=4, vis_channels=None):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)
        Visible.__init__(self, vis_channels)
        self.quant = quant
        self.bit_width = bit_width
        # slot for a reference feature handed over by a teacher layer
        self.register_buffer("ref_feat", None)

    def forward(self, input):
        if not self.quant:
            out = super().forward(input)
        else:
            qweight = NaiveQuantSTE.apply(self.weight, self.bit_width)
            out = F.conv2d(input, qweight, self.bias, self.stride,
                           self.padding, self.dilation, self.groups)
        self.forward_vis(out)
        return out
class NaiveQuantLinear(nn.Linear, Visible):
    """``nn.Linear`` with optional min/max weight quantization.

    With ``quant`` enabled, the weight is passed through
    ``NaiveQuantSTE`` at ``bit_width`` bits before the matmul. Every
    forward ends with ``forward_vis`` on the output.
    """

    def __init__(self, in_features, out_features, bias=True, quant=False,
                 bit_width=4, vis_channels=None):
        super().__init__(in_features, out_features, bias)
        Visible.__init__(self, vis_channels)
        self.quant = quant
        self.bit_width = bit_width
        # slot for a reference feature handed over by a teacher layer
        self.register_buffer("ref_feat", None)

    def forward(self, input):
        if not self.quant:
            out = super().forward(input)
        else:
            qweight = NaiveQuantSTE.apply(self.weight, self.bit_width)
            out = F.linear(input, qweight, self.bias)
        self.forward_vis(out)
        return out
###############################################################################
# Differentiable quant boundaries
###############################################################################
class RoundSTE(Function):
    """Round to the nearest integer, with an identity (straight-through) gradient."""

    @staticmethod
    def forward(ctx, x):
        rounded = torch.round(x)
        return rounded

    @staticmethod
    def backward(ctx, dy):
        # straight-through: treat rounding as the identity for gradients
        grad_in = dy.clone()
        return grad_in
class DiffBoundary:
    """Mixin adding learnable scalar quantization boundaries ``lb``/``ub``.

    The host must already have ``self.weight`` when ``DiffBoundary.__init__``
    runs. ``lb``/``ub`` start at the weight's min/max. The quantization step
    is ``(ub - lb) / (2 ** bit_width - 1)``. The float weight is detached
    inside the quantizers, so only ``lb``/``ub`` receive gradients through
    the quantized weight.
    """

    def __init__(self, bit_width=4):
        # TODO: add channel-wise option?
        self.bit_width = bit_width
        self.register_boundaries()

    def register_boundaries(self):
        """Create ``lb``/``ub`` Parameters from the current weight range."""
        assert hasattr(self, "weight")
        self.lb = Parameter(self.weight.data.min())
        self.ub = Parameter(self.weight.data.max())

    def reset_boundaries(self):
        """Overwrite ``lb``/``ub`` in place with the current weight range."""
        assert hasattr(self, "weight")
        self.lb.data = self.weight.data.min()
        self.ub.data = self.weight.data.max()

    def get_quant_weight(self, align_zero=True):
        """Return the quantized weight, on a zero-aligned grid if ``align_zero``."""
        if align_zero:
            return self._get_quant_weight_align_zero()
        return self._get_quant_weight()

    def _get_quant_weight(self):
        """Quantize onto the grid ``lb + k * delta`` for ``k`` in ``[0, n]``."""
        round_ = RoundSTE.apply
        w = self.weight.detach()
        delta = (self.ub - self.lb) / (2 ** self.bit_width - 1)
        w = torch.clamp(w, self.lb.item(), self.ub.item())
        idx = round_((w - self.lb).div(delta))  # TODO: do we need STE here?
        return idx * delta + self.lb

    def _get_quant_weight_align_zero(self):
        """Quantize onto the grid ``k * delta`` for ``k`` in ``[-z, n - z]``.

        ``z`` is the integer zero point, so 0 is exactly representable.
        NOTE(review): ``z`` is derived from ``|lb|``, which assumes
        ``lb <= 0 <= ub``. With ``lb > 0`` the grid lands on the wrong side
        of zero. TODO confirm this never happens for the layers using it.
        """
        round_ = RoundSTE.apply
        n = 2 ** self.bit_width - 1
        w = self.weight.detach()
        delta = (self.ub - self.lb) / n
        z = round_(self.lb.abs() / delta)
        lb = -z * delta
        ub = (n - z) * delta
        w = torch.clamp(w, lb.item(), ub.item())
        # Index from the *aligned* lower bound. Indexing from self.lb would
        # offset the rounding by up to half a step relative to the grid.
        idx = round_((w - lb).div(delta))  # TODO: do we need STE here?
        return (idx - z) * delta
class QConv2dDiffBounds(nn.Conv2d, DiffBoundary, Visible):
    """``nn.Conv2d`` whose weight can be quantized with learnable bounds.

    With ``quant`` enabled, the convolution uses ``self.get_quant_weight()``
    instead of the float weight. After every forward, ``vis_weight`` and
    ``vis_bound`` are invoked.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 quant=False, bit_width=4):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)
        self.quant = quant
        DiffBoundary.__init__(self, bit_width)
        Visible.__init__(self)

    def forward(self, input):
        if not self.quant:
            out = super().forward(input)
        else:
            qweight = self.get_quant_weight()
            out = F.conv2d(input, qweight, self.bias, self.stride,
                           self.padding, self.dilation, self.groups)
        self.vis_weight()
        self.vis_bound()
        return out
class QLinearDiffBounds(nn.Linear, DiffBoundary, Visible):
    """``nn.Linear`` whose weight can be quantized with learnable bounds.

    With ``quant`` enabled, the matmul uses ``self.get_quant_weight()``
    instead of the float weight. After every forward, ``vis_weight`` and
    ``vis_bound`` are invoked.
    """

    def __init__(self, in_features, out_features, bias=True, quant=False,
                 bit_width=4):
        super().__init__(in_features, out_features, bias)
        self.quant = quant
        DiffBoundary.__init__(self, bit_width)
        Visible.__init__(self)

    def forward(self, input):
        if not self.quant:
            out = super().forward(input)
        else:
            out = F.linear(input, self.get_quant_weight(), self.bias)
        self.vis_weight()
        self.vis_bound()
        return out
###############################################################################
# Probabilistic quant with local reparameterization trick
###############################################################################
def inv_sigmoid(x, lb, ub):
    """Return the logit (inverse sigmoid) of ``x`` after clamping it to ``[lb, ub]``.

    Clamping keeps the result finite as long as ``0 < lb`` and ``ub < 1``.
    """
    clamped = torch.clamp(x, lb, ub)
    return torch.log(1. / clamped - 1.).neg()
class TernaryConv2d(nn.Conv2d, Visible):
    """Implementation of `LR-Nets`(https://arxiv.org/abs/1710.07739).

    Each weight takes a value in {-1, 0, +1}. ``sigmoid(p_a)`` is P(w = 0)
    and ``sigmoid(p_b)`` is P(w = +1 | w != 0).
    - Training (``quant`` True): the output is drawn from a Gaussian whose
      mean and variance come from the weight distribution (local
      reparameterization trick).
    - Evaluation (``quant`` True): a concrete ternary weight is sampled on
      every forward call.
    - ``quant`` False: the layer acts as a plain ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, quant=False,
                 p_max=0.95, p_min=0.05, vis_channels=None):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)
        Visible.__init__(self, vis_channels)
        self.quant = quant
        self.p_max = p_max
        self.p_min = p_min
        self.eps = 1e-5  # added to the activation variance before sqrt
        self.register_buffer("w_candidate", torch.tensor([-1., 0., 1.]))
        # Slot read by forward_vis, registered as in the NaiveQuant layers so a
        # layer that never received a teacher feature does not raise AttributeError.
        self.register_buffer("ref_feat", None)
        self.p_a = Parameter(torch.zeros_like(self.weight))
        self.p_b = Parameter(torch.zeros_like(self.weight))
        self.reset_p()

    def reset_p(self):
        """Initialize the ``p_a``/``p_b`` logits from the current float weights."""
        w = self.weight.data / self.weight.data.std()
        # P(w = 0) decreases linearly with |w|, clamped to [p_min, p_max]
        self.p_a.data = inv_sigmoid(self.p_max - (self.p_max - self.p_min) * w.abs(),
                                    self.p_min, self.p_max)
        # P(w = +1 | w != 0) chosen so that, before clamping, E[w] equals the
        # std-normalized float weight
        self.p_b.data = inv_sigmoid(0.5 * (1. + w / (1. - torch.sigmoid(self.p_a.data))),
                                    self.p_min, self.p_max)

    def _weight_probs(self):
        """Stack P(w=-1), P(w=0), P(w=+1) along a new trailing dimension."""
        p_zero = torch.sigmoid(self.p_a)
        p_b = torch.sigmoid(self.p_b)
        p_pos = p_b * (1. - p_zero)
        p_neg = (1. - p_b) * (1. - p_zero)
        return torch.stack([p_neg, p_zero, p_pos], dim=-1)

    def forward(self, input):
        if self.quant:
            p = self._weight_probs()
            if self.training:
                # moments of the ternary weight distribution
                w_mean = (p * self.w_candidate).sum(dim=-1)
                w_var = (p * self.w_candidate.pow(2)).sum(dim=-1) - w_mean.pow(2)
                act_mean = F.conv2d(input, w_mean, self.bias, self.stride,
                                    self.padding, self.dilation, self.groups)
                act_var = F.conv2d(input.pow(2), w_var, None, self.stride,
                                   self.padding, self.dilation, self.groups)
                var_eps = torch.randn_like(act_mean)
                y = act_mean + var_eps * act_var.add(self.eps).sqrt()
            else:
                # one-hot draw per weight, converted to its candidate value
                indices = Multinomial(probs=p).sample().argmax(dim=-1)
                w = self.w_candidate[indices]
                y = F.conv2d(input, w, self.bias, self.stride,
                             self.padding, self.dilation, self.groups)
        else:
            y = super().forward(input)
        self.forward_vis(y)
        return y
class TernaryLinear(nn.Linear, Visible):
    """Implementation of `LR-Nets`(https://arxiv.org/abs/1710.07739).

    Each weight takes a value in {-1, 0, +1}. ``sigmoid(p_a)`` is P(w = 0)
    and ``sigmoid(p_b)`` is P(w = +1 | w != 0).
    - Training (``quant`` True): the output is drawn from a Gaussian whose
      mean and variance come from the weight distribution (local
      reparameterization trick).
    - Evaluation (``quant`` True): a concrete ternary weight is sampled on
      every forward call.
    - ``quant`` False: the layer acts as a plain ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, bias=True, quant=False,
                 p_max=0.95, p_min=0.05, vis_channels=None):
        super().__init__(in_features, out_features, bias)
        Visible.__init__(self, vis_channels)
        self.quant = quant
        self.p_max = p_max
        self.p_min = p_min
        self.eps = 1e-5  # added to the activation variance before sqrt
        self.register_buffer("w_candidate", torch.tensor([-1., 0., 1.]))
        # Slot read by forward_vis, registered as in the NaiveQuant layers so a
        # layer that never received a teacher feature does not raise AttributeError.
        self.register_buffer("ref_feat", None)
        self.p_a = Parameter(torch.zeros_like(self.weight))
        self.p_b = Parameter(torch.zeros_like(self.weight))
        self.reset_p()

    def reset_p(self):
        """Initialize the ``p_a``/``p_b`` logits from the current float weights."""
        w = self.weight.data / self.weight.data.std()
        # P(w = 0) decreases linearly with |w|, clamped to [p_min, p_max]
        self.p_a.data = inv_sigmoid(self.p_max - (self.p_max - self.p_min) * w.abs(),
                                    self.p_min, self.p_max)
        # P(w = +1 | w != 0) chosen so that, before clamping, E[w] equals the
        # std-normalized float weight
        self.p_b.data = inv_sigmoid(0.5 * (1. + w / (1. - torch.sigmoid(self.p_a.data))),
                                    self.p_min, self.p_max)

    def _weight_probs(self):
        """Stack P(w=-1), P(w=0), P(w=+1) along a new trailing dimension."""
        p_zero = torch.sigmoid(self.p_a)
        p_b = torch.sigmoid(self.p_b)
        p_pos = p_b * (1. - p_zero)
        p_neg = (1. - p_b) * (1. - p_zero)
        return torch.stack([p_neg, p_zero, p_pos], dim=-1)

    def forward(self, input):
        if self.quant:
            p = self._weight_probs()
            if self.training:
                # moments of the ternary weight distribution
                w_mean = (p * self.w_candidate).sum(dim=-1)
                w_var = (p * self.w_candidate.pow(2)).sum(dim=-1) - w_mean.pow(2)
                act_mean = F.linear(input, w_mean, self.bias)
                act_var = F.linear(input.pow(2), w_var, None)
                var_eps = torch.randn_like(act_mean)
                y = act_mean + var_eps * act_var.add(self.eps).sqrt()
            else:
                # one-hot draw per weight, converted to its candidate value
                indices = Multinomial(probs=p).sample().argmax(dim=-1)
                w = self.w_candidate[indices]
                y = F.linear(input, w, self.bias)
        else:
            y = super().forward(input)
        # Visualization hook. The original called self.forward(y), which
        # re-ran the layer on its own output (shape error or infinite recursion).
        self.forward_vis(y)
        return y