Skip to content

Commit

Permalink
MiniXception model instantiated programatically, only weights loaded …
Browse files Browse the repository at this point in the history
…from hdf5 file for unit tests to pass with keras 2.0 and 3.0
  • Loading branch information
Manojkumarmuru committed May 11, 2024
1 parent 524b6e0 commit 600f311
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 5 deletions.
114 changes: 111 additions & 3 deletions paz/models/classification/xception.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import get_file
from keras import layers


URL = 'https://github.com/oarriaga/altamira-data/releases/download/v0.6/'
Expand Down Expand Up @@ -84,6 +84,113 @@ def build_xception(
return model


def build_minixception(input_shape, num_classes, l2_reg=0.01):
"""Function for instantiating an Mini-Xception model.
# Arguments
input_shape: List corresponding to the input shape of the model.
num_classes: Integer.
l2_reg. Float. L2 regularization used in the convolutional kernels.
# Returns
Tensorflow-Keras model.
"""

regularization = l2(l2_reg)

# base
img_input = Input(input_shape)
x = Conv2D(5, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
use_bias=False)(img_input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(8, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

# module 1
residual = Conv2D(16, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(16, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(16, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])

# module 2
residual = Conv2D(32, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(32, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(32, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])

# module 3
residual = Conv2D(64, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(64, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(64, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])

# module 4
residual = Conv2D(128, (1, 1), strides=(1, 1),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(128, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(128, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

# x = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(x)
x = layers.add([x, residual])

x = Conv2D(num_classes, (3, 3),
# kernel_regularizer=regularization,
padding='same')(x)
x = GlobalAveragePooling2D()(x)
output = Activation('softmax', name='predictions')(x)

model = Model(img_input, output)
return model


def MiniXception(input_shape, num_classes, weights=None):
"""Build MiniXception (see references).
Expand All @@ -101,9 +208,10 @@ def MiniXception(input_shape, num_classes, weights=None):
Gender Classification](https://arxiv.org/abs/1710.07557)
"""
if weights == 'FER':
filename = 'fer2013_mini_XCEPTION.119-0.65.hdf5'
filename = 'fer2013_mini_XCEPTION.hdf5'
path = get_file(filename, URL + filename, cache_subdir='paz/models')
model = load_model(path)
model = build_minixception(input_shape, num_classes)
model.load_weights(path)
else:
stem_kernels = [32, 64]
block_data = [128, 128, 256, 256, 512, 512, 1024]
Expand Down
1 change: 0 additions & 1 deletion tests/paz/pipelines/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def labeled_emotion():
return 'happy'


@pytest.mark.skip()
def test_MiniXceptionFER(image_with_face, labeled_emotion, labeled_scores):
classifier = MiniXceptionFER()
inferences = classifier(image_with_face)
Expand Down
1 change: 0 additions & 1 deletion tests/paz/pipelines/detection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def test_HaarCascadeFrontalFace(image_with_faces, boxes_HaarCascadeFace):
assert_inferences(detector, image_with_faces, boxes_HaarCascadeFace)


@pytest.mark.skip()
def test_DetectMiniXceptionFER(image_with_faces, boxes_MiniXceptionFER):
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(1)
Expand Down

0 comments on commit 600f311

Please sign in to comment.