-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
118 lines (91 loc) · 3.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Parameter
def makedir(path):
    """Create directory *path* (including missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of the original check-then-create, which
    was TOCTOU-racy (another process could create the path between the
    ``lexists`` check and ``makedirs``) and silently skipped creation when
    *path* was a dangling symlink.
    """
    os.makedirs(path, exist_ok=True)
def get_data_from_iter(data_iter, data_loader):
    """Fetch the next batch from *data_iter*, restarting a fresh epoch
    from *data_loader* when the iterator is exhausted.

    Args:
        data_iter: the current iterator over *data_loader*.
        data_loader: the iterable to re-iterate when *data_iter* runs out.

    Returns:
        (batch, data_iter): the next batch and the (possibly re-created)
        iterator — callers must keep the returned iterator for the next call.
    """
    try:
        batch = next(data_iter)
    except StopIteration:
        # Only catch exhaustion; the original bare `except:` also swallowed
        # KeyboardInterrupt and genuine data-loading errors.
        data_iter = iter(data_loader)
        batch = next(data_iter)
    return batch, data_iter
# Image processing
def flip(ts, dim=-1):
    """Reverse tensor *ts* along dimension *dim* (torch analogue of seq[::-1])."""
    size = ts.size(dim)
    # Descending index [size-1, ..., 1, 0] on the tensor's own device.
    rev_idx = torch.arange(size - 1, -1, -1, dtype=torch.long, device=ts.device)
    return ts.index_select(dim, rev_idx)
# Model
def save_checkpoint(models, step, save_path):
    """Save the state_dict of every model in *models* under *save_path*.

    For each ``(key, model)`` pair, writes ``model-<key>-<step>.cpkt`` and
    refreshes a ``model-<key>-latest.cpkt`` symlink pointing at it.

    Args:
        models: dict mapping a name to a module (or DataParallel wrapper).
        step: training step used in the checkpoint filename.
        save_path: existing directory to write checkpoints into.
    """
    for key, model in models.items():
        # Unwrap (Distributed)DataParallel so keys are saved without "module."
        model = model.module if hasattr(model, "module") else model
        if not hasattr(model, "state_dict"):
            continue  # skip entries that are not modules
        data = model.state_dict()
        data_save_path = os.path.join(save_path, "model-%s-%s.cpkt" % (key, step))
        torch.save(data, data_save_path)
        latest_link = os.path.join(save_path, "model-%s-latest.cpkt" % key)
        # lexists (not exists) also catches a dangling link from a prior run.
        if os.path.lexists(latest_link):
            os.remove(latest_link)
        # Link to the basename: a symlink target resolves relative to the
        # link's own directory, so linking to the joined path (as the
        # original did) produced a dangling link whenever save_path was
        # relative. Both files live in save_path, so the basename is correct
        # for absolute and relative save_path alike.
        os.symlink(os.path.basename(data_save_path), latest_link)
    print("Save checkpoint, step: %s" % step)
def load_checkpoint(models, save_path, step=-1):
    """Restore every model in *models* from checkpoints under *save_path*.

    ``step > 0`` selects that exact step's file; otherwise the
    ``model-<key>-latest.cpkt`` symlink is followed. The load is non-strict:
    checkpoint entries are merged into the model's current state_dict, so a
    partial checkpoint does not raise on missing keys.
    """
    marker = "latest" if step <= 0 else step
    for key, wrapped in models.items():
        ckpt_path = os.path.join(save_path, "model-%s-%s.cpkt" % (key, marker))
        # Unwrap (Distributed)DataParallel before touching state_dict.
        target = wrapped.module if hasattr(wrapped, "module") else wrapped
        if not hasattr(target, "state_dict"):
            continue
        loaded = torch.load(ckpt_path, map_location="cpu")
        merged = target.state_dict()
        merged.update(loaded)
        target.load_state_dict(merged)
    print("Load checkpoint, step: %s" % marker)
# SpectralNorm
def l2normalize(v, eps=1e-12):
    """Scale *v* to unit L2 norm; *eps* guards against division by zero."""
    denom = v.norm() + eps
    return v / denom
class SpectralNorm(nn.Module):
    """Wrap *module* so its weight is divided by its largest singular value.

    Spectral normalization: before every forward pass, the top singular
    value ``sigma`` of the weight matrix is estimated by power iteration
    and ``module.<name>`` is set to ``weight / sigma``.  The raw weight is
    stored as ``<name>_bar``; ``<name>_u`` and ``<name>_v`` hold the
    left/right singular-vector estimates (registered with
    ``requires_grad=False``, so they are updated in-place, not trained).
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        # Attribute of `module` to normalize, e.g. "weight".
        self.name = name
        # Power-iteration steps per forward; u/v persist across calls, so
        # one step per forward is the usual setting.
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()
    def _update_u_v(self):
        # Refresh the u/v estimates and set module.<name> = w_bar / sigma.
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")
        # Flatten w to a (height, -1) matrix; dim 0 is taken as the output dim.
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # Power iteration via .data so these updates stay out of autograd.
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        # sigma ~ top singular value; computed WITH grad through w so the
        # normalization is differentiated (u, v act as constants here).
        sigma = u.dot(w.view(height, -1).mv(v))
        # Plain tensor assignment (the original parameter slot was freed in
        # _make_params), recomputed on every forward.
        setattr(self.module, self.name, w / sigma.expand_as(w))
    def _made_params(self):
        # True if the _u/_v/_bar parameters were already installed
        # (e.g. when re-wrapping an already-wrapped module).
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False
    def _make_params(self):
        # Move the original weight to <name>_bar and create random unit
        # vectors u (height) and v (width) for the power iteration.
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)
        # Remove the original Parameter so setattr in _update_u_v can bind a
        # plain (normalized) tensor under the same attribute name.
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)
    def forward(self, *args):
        # Re-normalize the wrapped module's weight, then delegate.
        self._update_u_v()
        return self.module.forward(*args)