resnet.py
import os.path
import torch.nn as nn
import torch
from torchvision import models


class Resnet(object):
    """
    Wrapper around a ResNet classification model that bundles the device,
    optimizer, loss criterion and an optional learning-rate scheduler, and
    exposes helpers to flatten parameters and gradients into single tensors.
    """
    def __init__(
            self,
            model,
            checkpoint,
            device,
            optimizer,
            criterion,
            lr_scheduler=None
    ):
        self.model = model
        # model.fc = nn.Linear(model.fc.in_features, n_classes)
        if checkpoint is not None:
            checkpoint_path = os.path.join('learners', checkpoint)
            self.model.load_state_dict(torch.load(checkpoint_path, map_location=device))

        self.device = device
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.criterion = criterion
        self.model_dim = int(self.get_param_tensor().shape[0])

    def optimizer_step(self):
        """
        Perform one optimizer step; requires the gradients to have been computed already.
        """
        self.optimizer.step()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def get_param_tensor(self):
        """
        Get `model` parameters as a single flattened tensor.

        :return: torch.Tensor
        """
        param_list = []
        for param in self.model.parameters():
            param_list.append(param.data.view(-1))

        return torch.cat(param_list)

    def get_grad_tensor(self):
        """
        Get `model` gradients as a single flattened tensor.

        :return: torch.Tensor
        """
        grad_list = []
        for param in self.model.parameters():
            if param.grad is not None:
                grad_list.append(param.grad.data.view(-1))

        return torch.cat(grad_list)

    def free_memory(self):
        """
        Free the memory allocated to the model weights and optimizer state.
        """
        del self.optimizer
        del self.model

    def free_gradients(self):
        """
        Free the memory allocated to the gradients.
        """
        self.optimizer.zero_grad(set_to_none=True)

    def freeze(self):
        """
        Freeze all model parameters (exclude them from gradient computation).
        """
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """
        Unfreeze all model parameters (include them in gradient computation).
        """
        for param in self.model.parameters():
            param.requires_grad = True
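
A minimal usage sketch (not part of resnet.py) showing how the wrapper is constructed and driven for one training step. The resnet18 backbone, SGD hyperparameters, cross-entropy loss, 10-class output, and random batch below are illustrative assumptions, not taken from the original file.

import torch
import torch.nn as nn
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
backbone = models.resnet18(num_classes=10).to(device)

learner = Resnet(
    model=backbone,
    checkpoint=None,  # or a state-dict file name located under 'learners/'
    device=device,
    optimizer=torch.optim.SGD(backbone.parameters(), lr=0.1, momentum=0.9),
    criterion=nn.CrossEntropyLoss(),
)

# One training step on a random placeholder batch.
x = torch.randn(8, 3, 224, 224, device=device)
y = torch.randint(0, 10, (8,), device=device)

learner.optimizer.zero_grad()
loss = learner.criterion(learner.model(x), y)
loss.backward()
learner.optimizer_step()

# Flattened views of parameters and gradients.
flat_params = learner.get_param_tensor()  # shape: (learner.model_dim,)
flat_grads = learner.get_grad_tensor()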