-
Notifications
You must be signed in to change notification settings - Fork 12
/
train.py
127 lines (108 loc) · 4.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import os
import sys
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import voc.transforms as transforms
from encoder import DataEncoder
from loss import FocalLoss
from retinanet import RetinaNet
from voc.datasets import VocLikeDataset
# ---- Command-line interface ----
parser = argparse.ArgumentParser(description='PyTorch RetinaNet Training')
parser.add_argument('--exp', required=True, help='experiment name')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

# Each experiment ships its own config module under exps/<exp>/config.py;
# prepend that directory to sys.path so the `import config` below resolves to it.
sys.path.insert(0, os.path.join('exps', args.exp))
import config as cfg

# Training is CUDA-only: the model and every batch are moved to the GPU below.
assert torch.cuda.is_available(), 'Error: CUDA not found!'

# Mutable training state; overwritten below when resuming from a checkpoint.
best_loss = float('inf')  # lowest average loss seen so far (checkpoint criterion)
start_epoch = 0
lr = cfg.lr

print('Preparing data..')
# Train-time pipeline: optional rescale, random horizontal flip, tensor
# conversion, then per-channel normalization with the config's statistics.
train_transform_list = [transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(cfg.mean, cfg.std)]
if cfg.scale is not None:
    train_transform_list.insert(0, transforms.Scale(cfg.scale))
train_transform = transforms.Compose(train_transform_list)
# Validation pipeline: no augmentation, only tensor conversion + normalization.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cfg.mean, cfg.std)
])

trainset = VocLikeDataset(image_dir=cfg.image_dir, annotation_dir=cfg.annotation_dir, imageset_fn=cfg.train_imageset_fn,
                          image_ext=cfg.image_ext, classes=cfg.classes, encoder=DataEncoder(), transform=train_transform)
valset = VocLikeDataset(image_dir=cfg.image_dir, annotation_dir=cfg.annotation_dir, imageset_fn=cfg.val_imageset_fn,
                        image_ext=cfg.image_ext, classes=cfg.classes, encoder=DataEncoder(), transform=val_transform)
# The datasets supply their own collate_fn (boxes per image vary in count).
trainloader = torch.utils.data.DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True,
                                          num_workers=cfg.num_workers, collate_fn=trainset.collate_fn)
valloader = torch.utils.data.DataLoader(valset, batch_size=cfg.batch_size, shuffle=False,
                                        num_workers=cfg.num_workers, collate_fn=valset.collate_fn)

print('Building model...')
net = RetinaNet(backbone=cfg.backbone, num_classes=len(cfg.classes))
if args.resume:
    print('Resuming from checkpoint..')
    checkpoint = torch.load(os.path.join('ckpts', args.exp, 'ckpt.pth'))
    net.load_state_dict(checkpoint['net'])
    best_loss = checkpoint['loss']
    start_epoch = checkpoint['epoch']
    lr = checkpoint['lr']
# Wrap in DataParallel across all visible GPUs; checkpoints are saved from
# net.module so the wrapped/unwrapped state dicts stay compatible.
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
net.cuda()
cudnn.benchmark = True  # let cuDNN autotune kernels (assumes fixed input sizes)
criterion = FocalLoss(len(cfg.classes))
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
def train(epoch):
    """Run one training epoch over ``trainloader`` and checkpoint the result.

    Args:
        epoch: 1-indexed epoch number (used only for logging).

    Side effects: updates ``net`` in place via ``optimizer``, prints per-batch
    losses, and calls ``save_checkpoint`` with the summed epoch loss.
    Reads module-level globals: ``net``, ``trainloader``, ``criterion``,
    ``optimizer``.
    """
    print('\nTrain Epoch: %d' % epoch)
    net.train()
    train_loss = 0
    for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(trainloader):
        # Variable() wrappers dropped: tensors autograd directly since 0.4.
        inputs = inputs.cuda()
        loc_targets = loc_targets.cuda()
        cls_targets = cls_targets.cuda()
        optimizer.zero_grad()
        loc_preds, cls_preds = net(inputs)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        # clip_grad_norm_ is the in-place variant; the underscore-less name is
        # a deprecated alias that was removed in modern PyTorch.
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.2)
        optimizer.step()
        # .item() replaces loss.data[0], which raises on 0-dim tensors
        # in PyTorch >= 0.4.
        batch_loss = loss.item()
        train_loss += batch_loss
        print('train_loss: %.3f | avg_loss: %.3f' % (batch_loss, train_loss/(batch_idx+1)))
    save_checkpoint(train_loss, len(trainloader))
def val(epoch):
    """Run one validation epoch over ``valloader`` and checkpoint the result.

    Args:
        epoch: 1-indexed epoch number (unused except for call-site symmetry
            with ``train``).

    Side effects: prints per-batch losses and calls ``save_checkpoint`` with
    the summed validation loss. Reads module-level globals: ``net``,
    ``valloader``, ``criterion``.
    """
    net.eval()
    val_loss = 0
    # Evaluation needs no gradients; no_grad() avoids building the autograd
    # graph and cuts GPU memory use (the original tracked gradients here).
    with torch.no_grad():
        for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(valloader):
            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            cls_targets = cls_targets.cuda()
            loc_preds, cls_preds = net(inputs)
            loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            # .item() replaces loss.data[0], which raises on 0-dim tensors
            # in PyTorch >= 0.4.
            batch_loss = loss.item()
            val_loss += batch_loss
            print('val_loss: %.3f | avg_loss: %.3f' % (batch_loss, val_loss/(batch_idx+1)))
    save_checkpoint(val_loss, len(valloader))
def save_checkpoint(loss, n):
    """Average ``loss`` over ``n`` batches; persist a checkpoint if it improves.

    Args:
        loss: summed loss over one epoch.
        n: number of batches the sum covers.

    Side effects: writes ckpts/<exp>/ckpt.pth and updates the global
    ``best_loss``. Also reads the module-level globals ``net``, ``args``,
    ``lr``, and ``epoch`` — the latter is the epoch-loop variable, so this
    must only be called from inside that loop.
    """
    global best_loss
    # Compute into a new name instead of mutating the parameter in place.
    avg_loss = loss / n
    if avg_loss < best_loss:
        print('Saving..')
        state = {
            # Unwrap DataParallel so the checkpoint loads into a bare model.
            'net': net.module.state_dict(),
            'loss': avg_loss,
            'epoch': epoch,
            'lr': lr
        }
        ckpt_path = os.path.join('ckpts', args.exp)
        # exist_ok=True replaces the isdir()-then-makedirs() pattern, which
        # races when two processes create the directory concurrently.
        os.makedirs(ckpt_path, exist_ok=True)
        torch.save(state, os.path.join(ckpt_path, 'ckpt.pth'))
        best_loss = avg_loss
# ---- Main epoch loop ----
# Epochs are 1-indexed and continue from start_epoch when resuming, so a
# resumed run picks up exactly where the checkpointed run left off.
for epoch in range(start_epoch + 1, start_epoch + cfg.num_epochs + 1):
    # Step decay: shrink the learning rate by 10x at each configured epoch
    # and push the new value into every optimizer parameter group.
    if epoch in cfg.lr_decay_epochs:
        lr *= 0.1
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    train(epoch)
    # Optionally validate every cfg.eval_every epochs while training.
    if cfg.eval_while_training and epoch % cfg.eval_every == 0:
        val(epoch)