-
Notifications
You must be signed in to change notification settings - Fork 6
/
model.py
68 lines (53 loc) · 3.01 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import sys
from keras import applications
from keras.models import Model, load_model
from keras.layers import Input, InputLayer, Conv2D, Activation, LeakyReLU, Concatenate
from layers import BilinearUpSampling2D
from loss import depth_loss_function
def create_model(existing='', is_twohundred=False, is_halffeatures=True):
    """Build a DenseNet encoder/decoder model, or load a saved one from disk.

    Args:
        existing: Path to a saved ``.h5`` model. When empty (the default) a
            new model is built instead of loading one.
        is_twohundred: Use ``DenseNet201`` as the encoder when true,
            otherwise ``DenseNet169``. Only used when building a new model.
        is_halffeatures: When true, the decoder starts with half as many
            filters as the encoder output has channels; otherwise it starts
            with the same number. Only used when building a new model.

    Returns:
        The Keras ``Model`` that was built or loaded.

    Raises:
        SystemExit: If ``existing`` is non-empty and does not end in ``.h5``.
    """
    if existing:
        # Load model from file
        if not existing.endswith('.h5'):
            sys.exit('Please provide a correct model file when using [existing] argument.')

        # The saved model references these project-defined objects by name,
        # so they are handed to the loader explicitly.
        custom_objects = {
            'BilinearUpSampling2D': BilinearUpSampling2D,
            'depth_loss_function': depth_loss_function,
        }
        model = load_model(existing, custom_objects=custom_objects)
        print('\nExisting model loaded.\n')
    else:
        print('Loading base model (DenseNet)..')

        # Encoder layers
        if is_twohundred:
            base_model = applications.DenseNet201(input_shape=(None, None, 3), include_top=False)
        else:
            base_model = applications.DenseNet169(input_shape=(None, None, 3), include_top=False)

        print('Base model loaded.')

        # Starting point for decoder
        base_model_output_shape = base_model.layers[-1].output.shape

        # Every encoder layer stays trainable (no layer freezing).
        for layer in base_model.layers:
            layer.trainable = True

        # Starting number of decoder filters
        encoder_channels = int(base_model_output_shape[-1])
        decode_filters = encoder_channels // 2 if is_halffeatures else encoder_channels

        def upproject(tensor, filters, name, concat_with):
            """Upsample ``tensor`` 2x, concatenate the encoder layer named
            ``concat_with`` (skip connection), then apply two 3x3
            convolutions, each followed by LeakyReLU."""
            up_i = BilinearUpSampling2D((2, 2), name=name + '_upsampling2d')(tensor)
            up_i = Concatenate(name=name + '_concat')(
                [up_i, base_model.get_layer(concat_with).output]
            )
            up_i = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same',
                          name=name + '_convA')(up_i)
            up_i = LeakyReLU(alpha=0.2)(up_i)
            up_i = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same',
                          name=name + '_convB')(up_i)
            up_i = LeakyReLU(alpha=0.2)(up_i)
            return up_i

        # Decoder layers
        # NOTE(review): input_shape looks redundant when the layer is called
        # directly on a tensor; kept unchanged so the layer is configured
        # exactly as before.
        decoder = Conv2D(filters=decode_filters, kernel_size=1, padding='same',
                         input_shape=base_model_output_shape, name='conv2')(base_model.output)

        # (filter divisor, stage name, encoder layer used for the skip connection).
        # The filter count halves at every stage. A fifth stage ('up5',
        # divisor 32, skip from 'input_1') was permanently disabled in the
        # original code and has been removed as unreachable.
        decoder_stages = (
            (2, 'up1', 'pool3_pool'),
            (4, 'up2', 'pool2_pool'),
            (8, 'up3', 'pool1'),
            (16, 'up4', 'conv1/relu'),
        )
        for divisor, stage_name, skip_layer in decoder_stages:
            decoder = upproject(decoder, decode_filters // divisor, stage_name,
                                concat_with=skip_layer)

        # Extract depths (final layer)
        conv3 = Conv2D(filters=1, kernel_size=3, strides=1, padding='same', name='conv3')(decoder)

        # Create the model
        model = Model(inputs=base_model.input, outputs=conv3)

    print('Model created.')

    return model