import argparse
import os
import torch
import numpy as np
from torch.utils import data
from tqdm import tqdm
from tensorboardX import SummaryWriter
from loss import CalculateLoss
from partial_conv_net import PartialConvUNet
from places2_train import Places2Data
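

# SubsetSampler yields the indices [start_sample, num_samples) in order, so the
# DataLoader below walks the training set sequentially, one full pass per epoch.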
class SubsetSampler(data.sampler.Sampler):
    def __init__(self, start_sample, num_samples):
        self.num_samples = num_samples
        self.start_sample = start_sample

    def __iter__(self):
        return iter(range(self.start_sample, self.num_samples))

    def __len__(self):
        return self.num_samples
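

# Predicate used with filter() below so the optimizer only receives parameters
# that still have requires_grad == True.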
def requires_grad(param):
    return param.requires_grad


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_path", type=str, default="/data_256")
    parser.add_argument("--mask_path", type=str, default="/mask")
    parser.add_argument("--val_path", type=str, default="/val_256")
    parser.add_argument("--log_dir", type=str, default="/training_logs")
    parser.add_argument("--save_dir", type=str, default="/model")
    parser.add_argument("--load_model", type=str)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--fine_tune_lr", type=float, default=5e-5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--fine_tune", action="store_true")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=32)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--save_interval", type=int, default=5000)
    args = parser.parse_args()
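
    # Example invocation (the paths below are placeholders; adjust them to the local
    # dataset layout relative to the working directory):
    #   python train.py --train_path /data_256 --mask_path /mask --batch_size 32 --gpu 0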

    cwd = os.getcwd()

    # Tensorboard SummaryWriter setup
    if not os.path.exists(cwd + args.log_dir):
        os.makedirs(cwd + args.log_dir)
    writer = SummaryWriter(cwd + args.log_dir)

    if not os.path.exists(cwd + args.save_dir):
        os.makedirs(cwd + args.save_dir)
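
    # Run on the requested CUDA device, or fall back to the CPU when --gpu is negative.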
    if args.gpu >= 0:
        device = torch.device("cuda:{}".format(args.gpu))
    else:
        device = torch.device("cpu")
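
    # The sequential sampler draws exactly iters_per_epoch full batches per epoch,
    # so the dataset size must divide evenly by the batch size (asserted below).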
    data_train = Places2Data(args.train_path, args.mask_path)
    data_size = len(data_train)
    print("Loaded training dataset with {} samples and {} masks".format(data_size, data_train.num_masks))
    assert data_size % args.batch_size == 0
    iters_per_epoch = data_size // args.batch_size

    # data_val = Places2Data(args.val_path, args.mask_path)
    # print("Loaded validation dataset...")

    # Move the model to the GPU before creating the optimizer, since .to(device)
    # replaces the parameters with different tensor objects
    model = PartialConvUNet().to(device)
    print("Loaded model to device...")
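
    # The original partial-convolution paper trains in two stages: an initial pass at the
    # base learning rate, then fine-tuning at a lower rate with the encoder's BatchNorm frozen.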
    # Use the fine-tuning learning rate (and freeze the encoder's batch norm layers)
    # when --fine_tune is set
    if args.fine_tune:
        lr = args.fine_tune_lr
        model.freeze_enc_bn = True
    else:
        lr = args.lr

    # Adam optimizer proposed in: "Adam: A Method for Stochastic Optimization"
    # Filters the model parameters for those with requires_grad == True
    optimizer = torch.optim.Adam(filter(requires_grad, model.parameters()), lr=lr)
    print("Setup Adam optimizer...")

    # Loss function
    # CalculateLoss holds the VGG-16 network used to extract feature maps for the loss,
    # so the whole loss module is moved to the same device as the model
    loss_func = CalculateLoss().to(device)
    print("Setup loss function...")

    # Resume training on model
    if args.load_model:
        assert os.path.isfile(cwd + args.save_dir + args.load_model)
        filename = cwd + args.save_dir + args.load_model
        checkpoint_dict = torch.load(filename)
        model.load_state_dict(checkpoint_dict["model"])
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
        print("Resume training on model:{}".format(args.load_model))

        # Load all parameters to the training device: optimizer state restored from a
        # checkpoint (e.g. Adam's moment buffers) may still live on the CPU
        model = model.to(device)
        for state in optimizer.state.values():
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device)

    for epoch in range(0, args.epochs):
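        # Build a fresh, sequential iterator over the training data at the start of each epoch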
        iterator_train = iter(data.DataLoader(data_train,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              sampler=SubsetSampler(0, data_size)))

        # TRAINING LOOP
        print("\nEPOCH:{} of {} - starting training loop from iteration:0 to iteration:{}\n"
              .format(epoch, args.epochs, iters_per_epoch))

        for i in tqdm(range(0, iters_per_epoch)):
            # Sets the model to train mode
            model.train()

            # Gets the next batch of images, masks and ground-truth targets
            image, mask, gt = [x.to(device) for x in next(iterator_train)]

            # Forward-propagates images through the net
            # The mask is also propagated, though it is usually gone by the decoding stage
            output = model(image, mask)

            loss_dict = loss_func(image, mask, output, gt)
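            # loss_dict maps each individual loss term (in the original paper: valid/hole
            # pixel, perceptual, style, and total-variation losses) to its value for this batch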
            loss = 0.0

            # Sums the loss terms, logging each one to Tensorboard every log_interval iterations
            for key, value in loss_dict.items():
                loss += value
                if (i + 1) % args.log_interval == 0:
                    writer.add_scalar(key, value.item(), (epoch * iters_per_epoch) + i + 1)
                    writer.file_writer.flush()

            # Resets the gradient accumulator in the optimizer
            optimizer.zero_grad()
            # Back-propagates gradients through the model weights
            loss.backward()
            # Updates the weights
            optimizer.step()

            # Save a checkpoint with both the model and optimizer state
            if (i + 1) % args.save_interval == 0 or (i + 1) == iters_per_epoch:
                filename = cwd + args.save_dir + "/model_e{}_i{}.pth".format(epoch, i + 1)
                state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
                torch.save(state, filename)

    writer.close()