-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
37 lines (27 loc) · 1.29 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import keras
from base_model import BaseModel
from keras.preprocessing.image import ImageDataGenerator
import model_file
class TrainModel:
    """Build (or load) a Keras classification model and train it on a data loader's generators."""

    @staticmethod
    def BeginTraining(epochs, batch_size, num_classes, dataloader, model_file,
                      weight_file, load_weights=False, load_model=False,
                      learning_rate=0.02):
        """Create or load a model, optionally load weights, compile it and train it.

        Args:
            epochs: Number of epochs passed to ``fit_generator``.
            batch_size: Accepted for interface compatibility but not used here;
                step counts are derived from each generator's own ``batch_size``.
            num_classes: Number of output classes handed to
                ``BaseModel().CreateModel`` when a new model is built.
            dataloader: Object exposing ``TrainGen()`` and ``ValGen()``. Each
                must return an iterator providing ``next()``, ``n`` and
                ``batch_size`` (presumably Keras ``ImageDataGenerator``
                iterators -- TODO confirm).
            model_file: File name, appended to ``'/models/'``, that is passed
                to ``model_file.LoadModel`` when ``load_model`` is true.
            weight_file: File name, appended to ``'./weights/'``, whose weights
                are loaded when ``load_weights`` is true.
            load_weights: If true, load weights into the model before compiling.
            load_model: If true, load the model via the ``model_file`` module's
                ``LoadModel`` instead of creating a fresh ``BaseModel``.
            learning_rate: SGD learning rate (defaults to 0.02, the value that
                was previously hard-coded).

        Returns:
            The compiled model after ``fit_generator`` has run.
        """
        train_gen = dataloader.TrainGen()
        val_gen = dataloader.ValGen()
        # Pull one batch only to discover the input shape. This advances the
        # training iterator by one batch before training starts.
        x_train, _ = train_gen.next()

        if load_model:
            # The ``model_file`` parameter (a file name) shadows the module
            # imported at file level as ``model_file``, so ``model_file.LoadModel``
            # would be looked up on the string. Bind the module under an alias.
            import model_file as model_file_module
            # NOTE(review): '/models/' is root-absolute whereas weights use the
            # relative './weights/'; confirm which location is intended.
            model = model_file_module.LoadModel('/models/' + model_file)
        else:
            # shape[1:] drops the leading (assumed batch) axis of the batch.
            model = BaseModel().CreateModel(x_train.shape[1:], num_classes)

        model.summary()

        if load_weights:
            model.load_weights('./weights/' + weight_file)

        # NOTE: ``lr`` is the legacy Keras keyword (newer releases use
        # ``learning_rate``); kept to match the older Keras API this file uses
        # (``fit_generator``).
        model.compile(loss='categorical_crossentropy',
                      optimizer=keras.optimizers.SGD(lr=learning_rate),
                      metrics=['accuracy'])

        # Floor division gives 0 when a generator holds fewer samples than one
        # batch, which would hand Keras a zero step count; use at least 1.
        step_size_train = max(1, train_gen.n // train_gen.batch_size)
        step_size_valid = max(1, val_gen.n // val_gen.batch_size)

        model.fit_generator(train_gen, epochs=epochs,
                            steps_per_epoch=step_size_train,
                            validation_data=val_gen,
                            validation_steps=step_size_valid,
                            shuffle=True, verbose=1)
        return model