-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
188 lines (137 loc) · 6.36 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import numpy as np
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
# from models import cls_model, seg_model
from models import cls_model, seg_model
from data_loader import get_data_loader
from utils import save_checkpoint, create_dir
def train(train_dataloader, model, opt, epoch, args, writer):
    """Run one training epoch and return the summed loss over all batches.

    Args:
        train_dataloader: iterable of (point_clouds, labels) batches.
        model: network called as ``model(point_clouds)``; logits shaped
            (B, C) for cls or (B, N, C) for seg — presumably, per the
            reshape below; confirm against models.py.
        opt: optimizer over ``model.parameters()``.
        epoch: current epoch index, used to offset the tensorboard step.
        args: namespace with ``.device``, ``.task`` ("cls" or "seg") and,
            for seg, ``.num_seg_class``.
        writer: tensorboard SummaryWriter; gets a per-step 'train_loss'.

    Returns:
        float: sum of per-batch loss values for this epoch.
    """
    model.train()
    model.to(args.device)
    step = epoch * len(train_dataloader)
    # Hoisted out of the loop: the criterion is stateless, so there is no
    # reason to construct a new one per batch.
    criterion = torch.nn.CrossEntropyLoss()
    epoch_loss = 0.0
    for i, batch in enumerate(train_dataloader):
        point_clouds, labels = batch
        point_clouds = point_clouds.to(args.device)
        labels = labels.to(args.device).to(torch.long)

        # Forward pass.
        predictions = model(point_clouds)
        if args.task == "seg":
            # Flatten (B, N) labels and (B, N, C) logits so the criterion
            # sees one prediction per point.
            labels = labels.reshape([-1])
            predictions = predictions.reshape([-1, args.num_seg_class])

        loss = criterion(predictions, labels)
        # Accumulate a Python float, not the loss tensor: keeping the
        # tensor would retain every batch's autograd graph for the whole
        # epoch and steadily leak memory.
        epoch_loss += loss.item()

        # Backward and optimize.
        opt.zero_grad()
        loss.backward()
        opt.step()

        writer.add_scalar('train_loss', loss.item(), step + i)
    return epoch_loss
def test(test_dataloader, model, epoch, args, writer):
    """Evaluate the model on the test set; log and return its accuracy.

    For classification the accuracy is the fraction of objects labelled
    correctly; for segmentation it is the fraction of individual points.
    Either way it is (# correct label entries) / (# label entries), so a
    single loop covers both tasks.
    """
    model.eval()
    num_correct = 0
    num_total = 0
    for point_clouds, labels in test_dataloader:
        point_clouds = point_clouds.to(args.device)
        labels = labels.to(args.device).to(torch.long)
        # Predicted class = argmax over the trailing (class) dimension.
        # This handles (B, C) cls logits and (B, N, C) seg logits alike.
        with torch.no_grad():
            pred_labels = model(point_clouds).argmax(dim=-1)
        num_correct += pred_labels.eq(labels).cpu().sum().item()
        # numel() counts B objects for cls and B*N points for seg —
        # identical to the per-task counts the two original branches used.
        num_total += labels.numel()
    accuracy = num_correct / num_total
    writer.add_scalar("test_acc", accuracy, epoch)
    return accuracy
def main(args):
    """Loads the data, creates checkpoint and sample directories, and starts the training loop.
    """
    # Directories for checkpoints and tensorboard logs.
    create_dir(args.checkpoint_dir)
    create_dir('./logs')

    # Tensorboard logger, namespaced by task and experiment name.
    run_name = args.task + "_" + args.exp_name
    writer = SummaryWriter('./logs/{0}'.format(run_name))

    # Pick the model for the requested task.
    model = cls_model() if args.task == "cls" else seg_model()

    # Optionally resume from a saved checkpoint.
    if args.load_checkpoint:
        model_path = "{}/{}.pt".format(args.checkpoint_dir, args.load_checkpoint)
        with open(model_path, 'rb') as f:
            state_dict = torch.load(f, map_location=args.device)
            model.load_state_dict(state_dict)
        print("successfully loaded checkpoint from {}".format(model_path))

    # Optimizer.
    opt = optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.999))

    # Dataloaders for training and testing.
    train_dataloader = get_data_loader(args=args, train=True)
    test_dataloader = get_data_loader(args=args, train=False)
    print("successfully loaded data")

    best_acc = -1
    print("======== start training for {} task ========".format(args.task))
    print("(check tensorboard for plots of experiment logs/{})".format(run_name))

    for epoch in range(args.num_epochs):
        # One pass over the training set, then evaluate.
        train_epoch_loss = train(train_dataloader, model, opt, epoch, args, writer)
        current_acc = test(test_dataloader, model, epoch, args, writer)
        print("epoch: {} train loss: {:.4f} test accuracy: {:.4f}".format(epoch, train_epoch_loss, current_acc))

        # Periodic checkpoint regardless of quality.
        if epoch % args.checkpoint_every == 0:
            print("checkpoint saved at epoch {}".format(epoch))
            save_checkpoint(epoch=epoch, model=model, args=args, best=False)

        # Track and save the best-so-far model.
        if current_acc >= best_acc:
            best_acc = current_acc
            print("best model saved at epoch {}".format(epoch))
            save_checkpoint(epoch=epoch, model=model, args=args, best=True)

    print("======== training completes ========")
def create_parser():
    """Creates a parser for command-line arguments.
    """
    p = argparse.ArgumentParser()

    # Task / model hyper-parameters.
    p.add_argument('--task', type=str, default="cls", help='The task: cls or seg')
    p.add_argument('--num_seg_class', type=int, default=6, help='The number of segmentation classes')

    # Training hyper-parameters.
    p.add_argument('--num_epochs', type=int, default=150)
    p.add_argument('--batch_size', type=int, default=16, help='The number of images in a batch.')
    p.add_argument('--num_workers', type=int, default=0, help='The number of threads to use for the DataLoader.')
    p.add_argument('--lr', type=float, default=0.001, help='The learning rate (default 0.001)')
    p.add_argument('--exp_name', type=str, default="exp", help='The name of the experiment')

    # Directories and checkpoint cadence.
    p.add_argument('--main_dir', type=str, default='./data/')
    p.add_argument('--checkpoint_dir', type=str, default='./checkpoints')
    p.add_argument('--checkpoint_every', type=int, default=10)
    p.add_argument('--load_checkpoint', type=str, default='')

    return p
if __name__ == '__main__':
    args = create_parser().parse_args()
    # Prefer GPU when available.
    if torch.cuda.is_available():
        args.device = torch.device("cuda")
    else:
        args.device = torch.device('cpu')
    # Checkpoint directory is task specific.
    args.checkpoint_dir = args.checkpoint_dir + "/" + args.task
    main(args)