-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_test.py
72 lines (63 loc) · 2.66 KB
/
train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import time
import torch
from tqdm import tqdm
# Pick the compute device once at import time: the first CUDA GPU when
# PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
def test(model, data_loader, verbose=False):
    """Measure the accuracy of a model on a data set.

    The predicted label is the argmax over dimension 1 of the model output;
    accuracy is the fraction of those predictions equal to the target.

    Args:
        model: PyTorch module producing per-class scores along dim 1.
            It is moved to the module-level ``device`` and put in eval mode
            (and left in eval mode on return).
        data_loader: Iterable of ``(features, target)`` batches with a
            ``len()`` (presumably a ``torch.utils.data.DataLoader``).
        verbose: If True, print ``correct / total (percent%)``.

    Returns:
        Accuracy as a percentage of the predictions actually evaluated,
        or 0.0 if the loader yields no predictions.
    """
    # Features are sent to ``device`` below, so the model must live there too;
    # this is a no-op when ``train`` has already moved it.
    model.to(device)
    # Make sure the model is in evaluation mode.
    model.eval()
    correct = 0
    total = 0
    # We do not need to maintain intermediate activations while testing.
    with torch.no_grad():
        # len(data_loader) equals the number of batches and, unlike
        # len(data_loader.batch_sampler), also works when batch_size=None.
        for features, target in tqdm(data_loader, total=len(data_loader), desc="Testing"):
            output = model(features.to(device))
            # Label with the highest predicted score along dim 1.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.cpu().eq(target.view_as(pred)).sum().item()
            # Count the predictions actually compared. len(data_loader.sampler)
            # over-counts with drop_last=True and is wrong whenever there is
            # more than one prediction per sample.
            total += pred.numel()
    # Guard against an empty loader instead of dividing by zero.
    percent = 100. * correct / total if total else 0.0
    if verbose:
        print(f'Test accuracy: {correct} / {total} ({percent:.0f}%)')
    return percent
def _train_one_epoch(model, criterion, data_loader, optimizer, ema_loss):
    """Run one optimization pass over ``data_loader``; return the updated loss EMA.

    ``ema_loss`` is the exponential moving average carried over from earlier
    epochs (``None`` until the first batch is seen). It is returned unchanged
    if the loader yields no batches.
    """
    model.train()
    for features, target in tqdm(data_loader, total=len(data_loader), desc="training"):
        # Forward pass. The model sits on ``device`` (moved by ``train``), so
        # its output is presumably already there; no extra transfer is needed.
        output = model(features.to(device))
        loss = criterion(output, target.to(device))
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE: .item() converts the loss to a Python float before it enters
        # the running average, so the average holds no tensors.
        batch_loss = loss.item()
        if ema_loss is None:
            ema_loss = batch_loss
        else:
            ema_loss += (batch_loss - ema_loss) * 0.01
    return ema_loss


def _rotate_checkpoint(model, epoch):
    """Save weights to ``model_{epoch}.ckpt`` (current directory) and delete the previous epoch's file."""
    torch.save(model.state_dict(), f'model_{epoch}.ckpt')
    print("Model Saved!")
    previous = f'model_{epoch-1}.ckpt'
    # Epoch 0 has no predecessor to remove.
    if epoch > 0 and os.path.isfile(previous):
        os.remove(previous)


def train(model, criterion, data_loader, test_loader, optimizer, num_epochs):
    """Simple training loop for a PyTorch model.

    Each epoch trains on ``data_loader``, evaluates on ``test_loader`` with
    ``test`` (printing its accuracy), prints the loss EMA and training time,
    then saves a checkpoint, keeping only the latest ``model_{epoch}.ckpt``.

    Args:
        model: Module to train; moved to the module-level ``device``.
        criterion: Loss callable ``criterion(output, target)``.
        data_loader: Training batches of ``(features, target)``.
        test_loader: Evaluation batches passed to ``test``.
        optimizer: Optimizer over the model's parameters (assumed; not checked).
        num_epochs: Number of passes over ``data_loader``.

    Returns:
        List of per-epoch test accuracies (percent).
    """
    model.to(device)
    accs = []
    # Exponential moving average of the loss, carried across epochs.
    ema_loss = None
    for epoch in range(num_epochs):
        tick = time.time()
        ema_loss = _train_one_epoch(model, criterion, data_loader, optimizer, ema_loss)
        # Time taken covers training only, not the evaluation below.
        tock = time.time()
        acc = test(model, test_loader, verbose=True)
        accs.append(acc)
        # No batch seen yet means no loss to report; show nan instead of crashing.
        shown_loss = float('nan') if ema_loss is None else ema_loss
        print('Epoch: {} \tLoss: {:.6f} \t Time taken: {:.6f} seconds'.format(epoch+1, shown_loss, tock-tick))
        _rotate_checkpoint(model, epoch)
    return accs