-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_classification.py
102 lines (83 loc) · 3.94 KB
/
train_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import print_function
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
# import torchvision.datasets as dset
# import torchvision.transforms as transforms
# import torchvision.utils as vutils
from torch.autograd import Variable
from datasets import PartDataset
from pointnet import PointNetCls
import torch.nn.functional as F
# Command-line options for the classification training run.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
# Forwarded to PartDataset as `npoints`; the help text used to be a copy of
# the --batchSize description.
parser.add_argument('--num_points', type=int, default=2500, help='number of points sampled per shape')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=0)
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='cls', help='output folder')
# An empty string means "start from freshly initialised weights".
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--n_views', type=int, default=13, help='view numbers')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
opt = parser.parse_args()
print(opt)
def blue(x):
    """Return the string *x* wrapped in ANSI escape codes (blue foreground)."""
    return '\033[94m' + x + '\033[0m'


# A fresh seed is drawn on every run and printed so the run can be identified
# later; it is not a fixed seed.
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
# Seed NumPy as well: it is imported by this script and is a separate RNG
# from both `random` and torch, so leaving it unseeded makes the printed seed
# insufficient to reproduce any NumPy-based sampling.
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
# Both splits are read from the same dataset directory and use the same
# loader settings, so the shared pieces are spelled out once.
_DATA_ROOT = 'shapenetcore_partanno_segmentation_benchmark_v0'
_LOADER_KWARGS = dict(
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers),
)

# Training split (PartDataset's default for `train` is used here).
dataset = PartDataset(root=_DATA_ROOT, classification=True, npoints=opt.num_points)
dataloader = torch.utils.data.DataLoader(dataset, **_LOADER_KWARGS)

# Split selected by train=False -- presumably the held-out test set; confirm
# against PartDataset. The loader is shuffled too, so pulling its first batch
# yields a different sample of this split each time.
test_dataset = PartDataset(
    root=_DATA_ROOT,
    classification=True,
    train=False,
    npoints=opt.num_points,
)
testdataloader = torch.utils.data.DataLoader(test_dataset, **_LOADER_KWARGS)

print(len(dataset), len(test_dataset))
# The classifier's output width is taken from the dataset's class list.
num_classes = len(dataset.classes)
print('classes', num_classes)
# Create the checkpoint directory. An already-existing directory is fine, but
# any other failure (permission denied, path is a regular file, ...) must
# surface now instead of at the first torch.save() after a full epoch.
try:
    os.makedirs(opt.outf)
except OSError:
    if not os.path.isdir(opt.outf):
        raise
# Build the classifier, optionally resuming from a saved state dict.
classifier = PointNetCls(k=num_classes, views=opt.n_views)
if opt.model != '':
    classifier.load_state_dict(torch.load(opt.model))

# Move the model to the GPU *before* creating the optimizer, so the optimizer
# is constructed over the parameters in their final location (the order
# recommended by the torch.optim documentation).
classifier.cuda()
optimizer = optim.SGD(classifier.parameters(), lr=opt.lr, momentum=opt.momentum)

# Used only for progress output. Under Python 3 this is a float; the progress
# lines format it with %d, which truncates it.
num_batch = len(dataset) / opt.batchSize
def _prepare_batch(batch):
    """Unpack a ``(points, target)`` batch and move both tensors to the GPU.

    Column 0 of ``target`` is used as the class index, and dims 1 and 2 of
    ``points`` are swapped before the forward pass.
    """
    points, target = batch
    target = target[:, 0]
    # Presumably (batch, num_points, 3) -> (batch, 3, num_points); confirm
    # against PointNetCls.
    points = points.transpose(2, 1)
    return points.cuda(), target.cuda()


# Main loop: one optimisation step per training batch, a quick evaluation on
# a single test batch every 10 steps, and one checkpoint per epoch.
for epoch in range(opt.nepoch):
    for i, data in enumerate(dataloader, 0):
        points, target = _prepare_batch(data)

        optimizer.zero_grad()
        # Re-enter train mode every step because the periodic evaluation
        # below switches the module to eval mode.
        classifier = classifier.train()
        pred = classifier(points)
        # nll_loss is applied directly, so `pred` is presumably
        # log-probabilities; confirm against PointNetCls.
        loss = F.nll_loss(pred, target)
        loss.backward()
        optimizer.step()

        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        # Divide by the actual batch length: the final batch of an epoch can
        # be shorter than opt.batchSize, which would understate accuracy.
        print('[%d: %d/%d] train loss: %f accuracy: %f' % (
            epoch, i, num_batch, loss.item(),
            correct.item() / float(target.size(0))))

        if i % 10 == 0:
            # Take the first batch from a fresh pass over the test loader.
            test_points, test_target = _prepare_batch(next(iter(testdataloader)))
            classifier = classifier.eval()
            # No gradients are needed for evaluation; skipping graph
            # construction saves memory and time.
            with torch.no_grad():
                test_pred = classifier(test_points)
                test_loss = F.nll_loss(test_pred, test_target)
            test_choice = test_pred.data.max(1)[1]
            test_correct = test_choice.eq(test_target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (
                epoch, i, num_batch, blue('test'), test_loss.item(),
                test_correct.item() / float(test_target.size(0))))

    # One checkpoint per epoch, named after the epoch index.
    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))