-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
95 lines (83 loc) · 3.85 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import csv
import copy
import time
from tqdm import tqdm
import torch
import numpy as np
import os
def train_model(model, criterion, dataloaders, optimizer, metrics, bpath, num_epochs=3):
    """Train ``model`` and return it loaded with its best-scoring weights.

    Every epoch runs a 'Train' phase followed by a 'Test' phase. The loss
    reported for a phase is the mean of its per-batch losses. One row per
    epoch is appended to ``log.csv`` inside ``bpath``, and the weights from
    the epoch with the lowest 'Test' loss are restored before returning.

    Args:
        model: Network whose forward pass returns a mapping with an 'out'
            entry. It is moved to CUDA device 0 when available, else CPU.
        criterion: Callable invoked as ``criterion(outputs['out'], masks)``;
            must return a scalar loss tensor.
        dataloaders: Mapping with 'Train' and 'Test' entries; each iterable
            yields samples holding 'image' and 'mask' tensors.
        optimizer: Optimizer stepped once per training batch.
        metrics: Accepted for interface compatibility; not evaluated here.
        bpath: Existing directory in which ``log.csv`` is (re)created.
        num_epochs: Number of epochs to run.

    Returns:
        ``model``, with the weights of the lowest-'Test'-loss epoch loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e20

    # Use gpu if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Start a fresh log file containing only the header row.
    fieldnames = ['epoch', 'Train_loss', 'Test_loss']
    log_path = os.path.join(bpath, 'log.csv')
    with open(log_path, 'w', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=fieldnames).writeheader()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        # Each epoch has a training and a testing phase.
        batchsummary = {'epoch': epoch}
        for phase in ['Train', 'Test']:
            is_train = phase == 'Train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            loss_total = 0.0
            batch_count = 0
            # Pass the loader itself (not iter(...)) so tqdm can read its
            # length and display a total.
            for sample in tqdm(dataloaders[phase]):
                inputs = sample['image'].to(device)
                masks = sample['mask'].to(device).long()

                # Drop singleton axes, e.g. (batch, 1, W, H) -> (batch, W, H).
                # A bare squeeze() would also drop the batch axis when the
                # batch holds a single sample, so restore it in that case.
                single_sample = masks.dim() > 0 and masks.size(0) == 1
                masks = masks.squeeze()
                if single_sample:
                    masks = masks.unsqueeze(0)

                if is_train:
                    optimizer.zero_grad()

                # Track gradient history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs['out'], masks)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                loss_total += loss.item()
                batch_count += 1

            # Mean over all batches of the phase rather than the last batch
            # alone; NaN when the loader produced no batches at all.
            if batch_count:
                epoch_loss = loss_total / batch_count
            else:
                epoch_loss = float('nan')
            batchsummary['{}_loss'.format(phase)] = epoch_loss
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

        print(batchsummary)
        with open(log_path, 'a', newline='') as csvfile:
            csv.DictWriter(csvfile, fieldnames=fieldnames).writerow(batchsummary)

        # Keep a copy of the weights whenever the test loss improves.
        # A NaN loss never compares lower, so it never replaces the best.
        test_loss = batchsummary['Test_loss']
        if test_loss < best_loss:
            best_loss = test_loss
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:.4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model