-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample_train.py
39 lines (30 loc) · 1.06 KB
/
example_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# example_train
# author: Duncan Tilley
# This script shows how to use the PPN module to train a model.
import tensorflow.compat.v1 as tf
import numpy as np
import ppn.config
import ppn.model
import ppn.data
def main(seed=None):
    """Train a PPN model, reload its best weights, and evaluate it.

    The data is split 70/15/15 into train/validation/test sets, with
    augmentation enabled for the training split only and shuffling
    disabled in the split.

    Args:
        seed: Optional integer seed. When given, NumPy's global RNG seed and
            TensorFlow's graph-level random seed are both set, for
            reproducible results. ``None`` (the default) leaves both
            unseeded, matching the original script's behaviour.

    Returns:
        The trained ``ppn.model.PpnModel`` instance, so the pipeline can be
        reused interactively after ``import``.
    """
    tf.logging.set_verbosity(tf.logging.ERROR)

    # Replaces the previously commented-out seeding lines with an opt-in flag.
    if seed is not None:
        np.random.seed(seed)
        # NOTE(review): this seeds the *default* graph. If PpnModel builds its
        # own tf.Graph, the seed may need to be set inside that graph instead
        # -- confirm.
        tf.set_random_seed(seed)

    # create the base and model configs
    config = ppn.config.ppn_config()
    resnet_config = ppn.config.resnet_config()

    # create the datasets; only the training split is augmented
    data = ppn.data.Data(patch_images=True)
    train, val, test = data.split_data(
        [0.7, 0.15, 0.15], augment=[True, False, False], shuffle=False
    )

    print('\nCreating data labels...\n')
    train = ppn.data.create_labeled_set(train, config)
    val = ppn.data.create_labeled_set(val, config)
    test = ppn.data.create_labeled_set(test, config)

    # create the PPN model
    print('\nCreating model...\n')
    model = ppn.model.PpnModel(
        ppn.model.get_resnet_constructor(resnet_config), config
    )

    # train the model on the dataset (the validation set is passed alongside)
    model.train(train, val, config)

    # reload the best weights and evaluate on the held-out test set
    model.load_weights(config)
    model.test(test, config)
    return model


if __name__ == "__main__":
    # Pass e.g. main(seed=1) for reproducible runs.
    main()