From 24034230fc1e5eae18e8cafab9f556b034902958 Mon Sep 17 00:00:00 2001 From: Osma Suominen Date: Wed, 16 Jun 2021 15:03:06 +0300 Subject: [PATCH] Custom Keras layer (MeanLayer) instead of Lambda which is hard to serialize --- annif/backend/nn_ensemble.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/annif/backend/nn_ensemble.py b/annif/backend/nn_ensemble.py index eb55ca228..4f4bf4bad 100644 --- a/annif/backend/nn_ensemble.py +++ b/annif/backend/nn_ensemble.py @@ -9,7 +9,7 @@ from scipy.sparse import csr_matrix, csc_matrix import joblib import lmdb -from tensorflow.keras.layers import Input, Dense, Add, Flatten, Lambda, Dropout +from tensorflow.keras.layers import Input, Dense, Add, Flatten, Dropout, Layer from tensorflow.keras.models import Model, load_model from tensorflow.keras.utils import Sequence import tensorflow.keras.backend as K @@ -74,6 +74,12 @@ def __len__(self): return int(np.ceil(self._counter / self._batch_size)) +class MeanLayer(Layer): + """Custom Keras layer that calculates mean values along the 2nd axis.""" + def call(self, inputs): + return K.mean(inputs, axis=2) + + class NNEnsembleBackend( backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend): @@ -112,7 +118,8 @@ def initialize(self): 'model file {} not found'.format(model_filename), backend_id=self.backend_id) self.debug('loading Keras model from {}'.format(model_filename)) - self._model = load_model(model_filename) + self._model = load_model(model_filename, + custom_objects={'MeanLayer': MeanLayer}) def _merge_hits_from_sources(self, hits_from_sources, params): score_vector = np.array([np.sqrt(hits.as_vector(subjects)) @@ -140,7 +147,7 @@ def _create_model(self, sources): kernel_initializer='zeros', bias_initializer='zeros')(drop_hidden) - mean = Lambda(lambda x: K.mean(x, axis=2))(inputs) + mean = MeanLayer()(inputs) predictions = Add()([mean, delta])