-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathoreba_3d_cnn.py
65 lines (51 loc) · 2.23 KB
/
oreba_3d_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""3D CNN Model"""
import tensorflow as tf
import oreba_building_blocks
# Name of the variable scope the model is built under by default.
SCOPE = "oreba_3d_cnn"
# Full-precision dtype; the custom getter below passes it straight through.
DEFAULT_DTYPE = tf.float32
# Reduced-precision dtypes whose variables are created as tf.float32 and then
# cast to the requested dtype by Model._custom_dtype_getter.
CASTABLE_TYPES = (tf.float16,)
# The default dtype followed by every castable dtype.
ALLOWED_TYPES = (DEFAULT_DTYPE, *CASTABLE_TYPES)
class Model:
    """Base class for building a 3D convolutional network.

    The network is assembled from helpers in ``oreba_building_blocks``: a
    stack of 3D convolution layers, a flatten, a dense layer and a final
    classification layer.
    """

    def __init__(self, params):
        """Create a model that learns features on batches of image sequences.

        Args:
            params: Hyperparameters. This class reads ``params.dtype`` and
                ``params.data_format`` and passes ``params`` through to the
                building-block functions.
        """
        self.params = params
        # NOTE(review): self.dtype is not read anywhere in this class;
        # presumably consumed by callers (e.g. to cast inputs) — verify.
        self.dtype = params.dtype

    def _custom_dtype_getter(self, getter, name, shape=None,
                             dtype=DEFAULT_DTYPE, *args, **kwargs):
        """Variable getter that keeps reduced-precision variables in fp32.

        When a variable whose dtype is in ``CASTABLE_TYPES`` is requested,
        the underlying variable is obtained from ``getter`` as ``tf.float32``
        and a cast of it to the requested dtype is returned. Any other dtype
        is forwarded to ``getter`` unchanged.

        Args:
            getter: The underlying variable getter supplied by TensorFlow.
            name: Name of the variable.
            shape: Shape of the variable.
            dtype: Requested dtype of the variable.
            *args: Extra positional arguments forwarded to ``getter``.
            **kwargs: Extra keyword arguments forwarded to ``getter``.

        Returns:
            The variable from ``getter``, or a ``tf.cast`` of a float32
            variable to ``dtype`` when ``dtype`` is castable.
        """
        if dtype in CASTABLE_TYPES:
            # Store the weights in float32; only the returned view is cast.
            var = getter(name, shape, tf.float32, *args, **kwargs)
            return tf.cast(var, dtype=dtype, name=name + '_cast')
        return getter(name, shape, dtype, *args, **kwargs)

    def __call__(self, inputs, is_training, scope=SCOPE):
        """Build the forward pass for a batch of image sequences.

        Args:
            inputs: A tensor representing a batch of input image sequences.
                Presumably a 5-D channels_last tensor; when
                ``params.data_format`` is 'channels_first', axis 4 is moved
                to axis 1 before the convolutions.
            is_training: A boolean representing whether training is active.
            scope: Name of the variable scope the model's variables are
                created in.

        Returns:
            A tuple ``(logits, features)``. ``logits`` is the output of the
            classification layer (presumably [batch_size, num_classes]), and
            ``features`` is the output of the dense layer that feeds it.
        """
        with tf.compat.v1.variable_scope(
                scope, custom_getter=self._custom_dtype_getter):
            # Convert to channels_first if necessary (performance boost):
            # moves the last axis of the 5-D input to position 1.
            if self.params.data_format == 'channels_first':
                inputs = tf.transpose(a=inputs, perm=[0, 4, 1, 2, 3])
            features = oreba_building_blocks.conv3d_layers(
                inputs=inputs,
                is_training=is_training,
                params=self.params)
            features = tf.keras.layers.Flatten()(features)
            features = oreba_building_blocks.dense_layer(
                inputs=features,
                is_training=is_training,
                params=self.params)
            logits = oreba_building_blocks.class_layer(
                inputs=features,
                is_training=is_training,
                params=self.params)
            return logits, features