# train.py
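"""Training entry point for the semantic segmentation models in this repository.

Builds a model via `builders.builder`, resumes from the newest checkpoint in the
weights directory when one exists (otherwise loads the pretrained DenseNet121
backbone), fine-tunes the last layers, and trains with the generators from
`tools.dataloader`.
"""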
import argparse
import os

# Select the GPU before TensorFlow initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

from builders import builder
from config import Config
from tools import dataloader
from tools.callbacks import LearningRateScheduler
from tools.learning_rate import lr_decays_func
from tools.metrics import MeanIoU

cfg = Config('train')

# Uncomment to let TensorFlow grow GPU memory on demand instead of pre-allocating it all:
# config = tf.compat.v1.ConfigProto()
# config.gpu_options.allow_growth = True
# tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))


def args_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='Choose the semantic segmentation method.', type=str, default='DeepLabV3Plus')
    parser.add_argument('--backBone', help='Choose the backbone model.', type=str, default='DenseNet121')
    parser.add_argument('--num_epochs', help='Number of training epochs.', type=int, default=cfg.epoch)
    parser.add_argument('--weights', help='The directory where weights are saved and loaded from.', type=str,
                        default='weights')
    parser.add_argument('--lr_scheduler', help='The strategy to schedule the learning rate.', type=str,
                        default='cosine_decay',
                        choices=['step_decay', 'poly_decay', 'cosine_decay'])
    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so expose warm-up as a flag instead.
    parser.add_argument('--lr_warmup', help='Whether to use learning rate warm-up.', action='store_true')
    parser.add_argument('--learning_rate', help='Initial learning rate.', type=float, default=cfg.lr)
    args = parser.parse_args()
    return args


def train(args):
    filepath = "weights-{epoch:03d}-{val_loss:.4f}-{val_mean_iou:.4f}.h5"
    weights_dir = os.path.join(args.weights, args.backBone + '_' + args.model)
    cfg.check_folder(weights_dir)
    model_weights = os.path.join(weights_dir, filepath)

    # build the model
    model, base_model = builder(cfg.n_classes, (256, 256), args.model, args.backBone)
    model.summary()

    # compile the model; adam is an unused alternative kept for reference
    sgd = optimizers.SGD(learning_rate=0.0001, momentum=0.9)
    adam = optimizers.Adam(learning_rate=cfg.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=[MeanIoU(cfg.n_classes)])
    # checkpoint setting: keep only the best weights by validation loss
    model_checkpoint = ModelCheckpoint(model_weights, monitor='val_loss', save_best_only=True, mode='auto')

    # learning rate scheduler setting (built but not registered below; add it to
    # `callbacks` to use it instead of ReduceLROnPlateau)
    lr_decay = lr_decays_func(args.lr_scheduler, args.learning_rate, args.num_epochs, args.lr_warmup)
    learning_rate_scheduler = LearningRateScheduler(lr_decay, args.learning_rate, args.lr_warmup, cfg.steps_per_epoch,
                                                    num_epochs=args.num_epochs, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_mean_iou', mode='max', patience=2, verbose=1, factor=0.2, min_lr=1e-7)
    # callbacks = [model_checkpoint]
    callbacks = [model_checkpoint, reduce_lr]
    # data generators
    train_set = dataloader.train_data_generator(cfg.train_data_path, cfg.train_label_path, cfg.batch_size,
                                                cfg.n_classes, cfg.data_augment)
    val_set = dataloader.val_data_generator(cfg.val_data_path, cfg.val_label_path, cfg.batch_size, cfg.n_classes)

    start_epoch = 0
    if os.path.exists(weights_dir) and os.listdir(weights_dir):
        # resume from the latest checkpoint; the {epoch:03d} padding makes
        # lexicographic order match epoch order
        checkpoints = sorted(os.listdir(weights_dir))
        latest = os.path.join(weights_dir, checkpoints[-1])
        model.load_weights(latest, by_name=True)
        # if loading succeeds, report it
        print('loaded :' + '-' * 8 + latest)
        # recover the epoch from the "weights-{epoch:03d}-..." filename
        start_epoch = int(checkpoints[-1][8:11])
        # fine-tune: freeze everything except the last 26 layers
        for layer in model.layers:
            layer.trainable = False
        for i in range(-1, -27, -1):
            model.layers[i].trainable = True
        model.summary()
    print("start_epoch: ", start_epoch)

    if start_epoch == 0:
        # fresh run: load the ImageNet-pretrained DenseNet121 backbone by layer name
        backbone_pretrained_path = os.path.join("backBone_pretrained_weights",
                                                'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5')
        model.load_weights(backbone_pretrained_path, by_name=True)
        print(f"loaded : {backbone_pretrained_path}")
        print(len(model.layers))
        # fine-tune: freeze everything except the last 26 layers
        for layer in model.layers:
            layer.trainable = False
        for i in range(-1, -27, -1):
            model.layers[i].trainable = True
        model.summary()
    # training...
    model.fit(train_set,
              steps_per_epoch=cfg.steps_per_epoch,
              epochs=args.num_epochs,
              callbacks=callbacks,
              validation_data=val_set,
              validation_steps=cfg.validation_steps,
              max_queue_size=cfg.batch_size,
              initial_epoch=start_epoch)


if __name__ == '__main__':
    args = args_parse()
    train(args)
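# Example invocation (a sketch; the data paths come from config.py, and the
# DenseNet121 notop weights are expected under backBone_pretrained_weights/):
#   python train.py --model DeepLabV3Plus --backBone DenseNet121 --lr_scheduler cosine_decay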