task.py
import logging
from typing import List

import torch
from torch import optim, nn
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.transforms import transforms

from metrics.accuracy_metric import AccuracyMetric
from metrics.metric import Metric
from metrics.test_loss_metric import TestLossMetric
from tasks.batch import Batch
from utils.parameters import Params

logger = logging.getLogger('logger')

class Task:
    params: Params = None
    train_dataset = None
    test_dataset = None
    train_loader = None
    test_loader = None
    classes = None

    model: Module = None
    optimizer: optim.Optimizer = None
    criterion: Module = None
    scheduler: CosineAnnealingLR = None
    metrics: List[Metric] = None

    # Standard ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    "Generic normalization for input data."

    input_shape: torch.Size = None

    def __init__(self, params: Params):
        self.params = params
        self.init_task()

    def init_task(self):
        self.load_data()
        self.model = self.build_model()
        self.resume_model()
        self.model = self.model.to(self.params.device)

        self.optimizer = self.make_optimizer()
        self.criterion = self.make_criterion()
        self.metrics = [AccuracyMetric(), TestLossMetric(self.criterion)]
        self.set_input_shape()

    def load_data(self) -> None:
        raise NotImplementedError

    def build_model(self) -> Module:
        raise NotImplementedError

    def make_criterion(self) -> Module:
        """Initialize the loss with Cross Entropy by default.

        Reduction `none` keeps per-sample losses, which is needed to
        support the gradient shaping defense.

        :return: loss criterion module.
        """
        return nn.CrossEntropyLoss(reduction='none')
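
    # With reduction='none' the criterion returns one loss per sample, so
    # training code is expected to reduce explicitly, e.g.
    # `self.criterion(outputs, labels).mean()`; this leaves room to reweight
    # or clip individual sample losses before the reduction.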

    def make_optimizer(self, model=None) -> Optimizer:
        if model is None:
            model = self.model
        if self.params.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(),
                                  lr=self.params.lr,
                                  weight_decay=self.params.decay,
                                  momentum=self.params.momentum)
        elif self.params.optimizer == 'Adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=self.params.lr,
                                   weight_decay=self.params.decay)
        else:
            raise ValueError(f'No optimizer: {self.params.optimizer}')

        return optimizer

    def make_scheduler(self) -> None:
        if self.params.scheduler:
            # Anneal the learning rate along a cosine curve over the whole
            # run (assumes one scheduler.step() call per epoch).
            self.scheduler = CosineAnnealingLR(self.optimizer,
                                               T_max=self.params.epochs)

    def resume_model(self):
        if self.params.resume_model:
            logger.info(f'Resuming training from {self.params.resume_model}')
            loaded_params = torch.load(f"saved_models/"
                                       f"{self.params.resume_model}",
                                       map_location=torch.device('cpu'))
            self.model.load_state_dict(loaded_params['state_dict'])
            self.params.start_epoch = loaded_params['epoch']
            self.params.lr = loaded_params.get('lr', self.params.lr)
            logger.warning(f"Loaded parameters from saved model: LR is"
                           f" {self.params.lr} and current epoch is"
                           f" {self.params.start_epoch}")

    def set_input_shape(self):
        inp = self.train_dataset[0][0]
        self.params.input_shape = inp.shape

    def get_batch(self, batch_id, data) -> Batch:
        """Process data into a batch.

        Data loaders for different datasets return differently structured
        objects; this method unifies them by wrapping the data in a Batch.

        :param batch_id: id of the batch
        :param data: object returned by the Loader.
        :return: a Batch moved to the task's device.
        """
        inputs, labels = data
        batch = Batch(batch_id, inputs, labels)

        return batch.to(self.params.device)

    def accumulate_metrics(self, outputs, labels):
        for metric in self.metrics:
            metric.accumulate_on_batch(outputs, labels)

    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_metric()

    def report_metrics(self, step, prefix='',
                       tb_writer=None, tb_prefix='Metric/'):
        metric_text = []
        for metric in self.metrics:
            metric_text.append(str(metric))
            metric.plot(tb_writer, step, tb_prefix=tb_prefix)
        logger.warning(f'{prefix} {step:4d}. {" | ".join(metric_text)}')

        return self.metrics[0].get_main_metric_value()

    @staticmethod
    def get_batch_accuracy(outputs, labels, top_k=(1,)):
        """Computes the precision@k for the specified values of k.

        For example, top_k=(1, 5) returns the batch top-1 and top-5
        accuracies as percentages.
        """
        max_k = max(top_k)
        batch_size = labels.size(0)

        # Indices of the max_k highest-scoring classes for each sample,
        # transposed to shape (max_k, batch_size).
        _, pred = outputs.topk(max_k, 1, True, True)
        pred = pred.t()
        correct = pred.eq(labels.view(1, -1).expand_as(pred))

        res = []
        for k in top_k:
            # reshape (not view): slices of the non-contiguous `correct`
            # tensor cannot always be viewed.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append((correct_k.mul_(100.0 / batch_size)).item())
        if len(res) == 1:
            res = res[0]
        return res
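

# Usage sketch (illustrative, not part of the original file): a subclass only
# has to supply load_data() and build_model(); init_task() then wires up the
# optimizer, criterion, metrics, and input shape. ToyTask and its random
# dataset are assumptions for demonstration, and instantiating it requires a
# Params object configured with the fields this file reads (device,
# optimizer, lr, decay, momentum, epochs, ...).
from torch.utils.data import TensorDataset


class ToyTask(Task):
    def load_data(self) -> None:
        # 64 random RGB "images" with labels drawn from 10 classes.
        x = torch.randn(64, 3, 32, 32)
        y = torch.randint(0, 10, (64,))
        self.train_dataset = TensorDataset(x, y)
        self.test_dataset = TensorDataset(x, y)

    def build_model(self) -> Module:
        return nn.Sequential(nn.Flatten(),
                             nn.Linear(3 * 32 * 32, 10))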