"""train.py: train a small ViTForClassification model on CIFAR-10."""
import argparse

import torch
from torch import nn, optim

from data import prepare_data
from utils import save_checkpoint, save_experiment
from vit import ViTForClassification

# Hyperparameters for a small ViT sized for 32x32 CIFAR-10 images.
config = {
    "patch_size": 4,
    "hidden_size": 48,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "intermediate_size": 4 * 48,
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 10,
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True,
}
# Consistency checks on the configuration.
assert config["hidden_size"] % config["num_attention_heads"] == 0
assert config["intermediate_size"] == 4 * config["hidden_size"]
assert config["image_size"] % config["patch_size"] == 0
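# With these values the model splits each 32x32 image into (32 / 4) ** 2 = 64
# patches of 4x4 pixels, each projected to a 48-dim embedding; assuming the
# standard ViT recipe with a prepended [CLS] token, the encoder then
# processes sequences of 65 tokens.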


class Trainer:
    """Simple wrapper around the training and evaluation loops."""

    def __init__(self, model, optimizer, loss_fn, exp_name, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device

    def train(self, trainloader, testloader, epochs, save_model_every_n_epochs=0):
        """Train the model for `epochs` epochs, evaluating after each one."""
        train_losses, test_losses, accuracies = [], [], []
        for i in range(epochs):
            train_loss = self.train_epoch(trainloader)
            accuracy, test_loss = self.evaluate(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {i+1}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
            if save_model_every_n_epochs > 0 and (i + 1) % save_model_every_n_epochs == 0 and i + 1 != epochs:
                print("\tSave checkpoint at epoch", i + 1)
                save_checkpoint(self.exp_name, self.model, i + 1)
        save_experiment(self.exp_name, config, self.model, train_losses, test_losses, accuracies)

    def train_epoch(self, trainloader):
        """Run one epoch of training and return the average loss per example."""
        self.model.train()
        total_loss = 0
        for batch in trainloader:
            # Move the batch to the target device.
            batch = [t.to(self.device) for t in batch]
            images, labels = batch
            self.optimizer.zero_grad()
            # The model returns a tuple; only the first output (logits) feeds the loss.
            loss = self.loss_fn(self.model(images)[0], labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item() * len(images)
        return total_loss / len(trainloader.dataset)

    @torch.no_grad()
    def evaluate(self, testloader):
        """Return (accuracy, average loss) of the model on the test set."""
        self.model.eval()
        total_loss = 0
        correct = 0
        for batch in testloader:
            batch = [t.to(self.device) for t in batch]
            images, labels = batch
            logits, _ = self.model(images)
            loss = self.loss_fn(logits, labels)
            total_loss += loss.item() * len(images)
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--device", type=str)
    parser.add_argument("--save-model-every", type=int, default=0)
    args = parser.parse_args()
    # Default to CUDA when available.
    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {args.device}")
    return args


def main():
    args = parse_args()
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr
    device = args.device
    save_model_every_n_epochs = args.save_model_every
    # Load the CIFAR-10 dataset.
    trainloader, testloader, _ = prepare_data(batch_size=batch_size)
    # Create the model, optimizer, loss function and trainer.
    model = ViTForClassification(config)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    loss_fn = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, loss_fn, args.exp_name, device=device)
    trainer.train(trainloader, testloader, epochs, save_model_every_n_epochs=save_model_every_n_epochs)


if __name__ == "__main__":
    main()
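# Example invocation (the experiment name below is just a placeholder):
#   python train.py --exp-name vit-cifar10 --batch-size 256 --epochs 100 --lr 1e-2 --save-model-every 10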