-
Notifications
You must be signed in to change notification settings - Fork 186
/
Copy pathmain.py
117 lines (88 loc) · 3.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.
#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.
import os
from resnet20 import resnet20
import torch
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import math
parser = argparse.ArgumentParser(description='train-addernet')
# Filesystem locations: CIFAR-10 root directory and the directory the
# trained model is written to.
parser.add_argument('--data', type=str, default='/cache/data/')
parser.add_argument('--output_dir', type=str, default='/cache/models/')
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)

# Most recent and best test accuracy seen so far; test() rebinds both.
acc = 0
acc_best = 0

# Per-channel mean/std used to normalise both training and test images.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)


def _cifar_transform(*augmentations):
    """Build a pipeline: the given augmentations, then ToTensor and Normalize."""
    steps = list(augmentations)
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD))
    return transforms.Compose(steps)


# Training images: random 32x32 crop (4-pixel padding) and horizontal flip.
transform_train = _cifar_transform(
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
)
# Test images: conversion and normalisation only.
transform_test = _cifar_transform()

data_train = CIFAR10(args.data, transform=transform_train)
data_test = CIFAR10(args.data, train=False, transform=transform_test)

data_train_loader = DataLoader(
    data_train,
    batch_size=256,
    shuffle=True,
    num_workers=8,
)
data_test_loader = DataLoader(
    data_test,
    batch_size=100,
    num_workers=0,
)

# Network and loss are moved to the GPU; SGD with momentum and weight decay.
net = resnet20().cuda()
criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
def adjust_learning_rate(optimizer, epoch, total_epochs=400, base_lr=0.1):
    """Set every param group's learning rate from a cosine-annealing schedule.

    The rate is ``base_lr / 2 * (1 + cos(epoch / total_epochs * pi))``. It
    equals ``base_lr`` at epoch 0 and decays smoothly to 0 at
    ``total_epochs``. With the defaults this reproduces the original
    hard-coded schedule ``0.05 * (1 + cos(epoch / 400 * pi))`` exactly.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts), e.g. a
            ``torch.optim.Optimizer``; each group's ``'lr'`` is overwritten.
        epoch: current epoch index.
        total_epochs: length of the cosine schedule in epochs.
        base_lr: peak learning rate, reached at epoch 0.

    Returns:
        The learning rate that was written into every param group.
    """
    # Same operation order as the original formula so the default schedule is
    # bit-for-bit identical (base_lr / 2 == 0.05 exactly for base_lr=0.1).
    lr = base_lr / 2 * (1 + math.cos(float(epoch) / total_epochs * math.pi))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def train(epoch):
    """Run one training epoch over ``data_train_loader``.

    Sets the learning rate for ``epoch`` via ``adjust_learning_rate``, switches
    ``net`` to training mode and takes one optimizer step per mini-batch using
    the module-level ``criterion`` and ``optimizer``. The loss of the batch
    with index 1 is printed once per epoch as a progress indicator.

    Args:
        epoch: current epoch number, forwarded to the LR schedule and the log.
    """
    adjust_learning_rate(optimizer, epoch)
    net.train()
    for i, (images, labels) in enumerate(data_train_loader):
        # Variable() has been a no-op wrapper since PyTorch 0.4; move the
        # tensors to the GPU directly.
        images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        output = net(images)
        loss = criterion(output, labels)
        if i == 1:
            # .item() forces a device sync, so only do it when actually logging.
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.item()))
        loss.backward()
        optimizer.step()
def test():
    """Evaluate ``net`` on ``data_test`` and update the accuracy globals.

    Computes the per-sample average cross-entropy loss and top-1 accuracy
    over ``data_test_loader``. The accuracy is stored in the global ``acc``,
    and ``acc_best`` is raised when it improves. Both metrics are printed.
    """
    global acc, acc_best
    net.eval()
    total_correct = 0
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in data_test_loader:
            images, labels = images.cuda(), labels.cuda()
            output = net(images)
            # criterion is built as CrossEntropyLoss() with its default
            # reduction='mean', so it returns the batch *mean*. Scale it back to
            # a batch sum so that dividing by the dataset size below gives the
            # true per-sample average (previously the mean was summed per batch
            # and then divided by the sample count, under-reporting by ~batch_size).
            total_loss += criterion(output, labels).item() * labels.size(0)
            pred = output.argmax(dim=1)
            total_correct += pred.eq(labels).sum().item()
    avg_loss = total_loss / len(data_test)
    acc = float(total_correct) / len(data_test)
    if acc_best < acc:
        acc_best = acc
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, acc))
def train_and_test(epoch):
    """Run one epoch of training, then a full evaluation pass.

    Args:
        epoch: epoch number handed to ``train`` (``test`` takes no arguments).
    """
    train(epoch)
    test()
def main():
    """Train and evaluate for each epoch, then save the whole model object.

    The final ``net`` (not just its state dict) is written with ``torch.save``
    to ``<output_dir>/addernet``.
    """
    epoch = 400
    # NOTE(review): range(1, epoch) visits epochs 1..399, i.e. 399 passes; the
    # cosine schedule in adjust_learning_rate reaches lr=0 only at epoch 400.
    # Kept as-is to preserve the original training recipe.
    for e in range(1, epoch):
        train_and_test(e)
    # os.path.join is correct whether or not --output_dir ends with a
    # separator; plain concatenation turned '/models' into '/modelsaddernet'.
    torch.save(net, os.path.join(args.output_dir, 'addernet'))


if __name__ == '__main__':
    main()