Skip to content

Commit b3ffea6

Browse files
authored
Cherrypick Sequential serialization bug fix for r2.13 (#18258)
1 parent 87db506 commit b3ffea6

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

keras/engine/sequential.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from keras.engine import training
2727
from keras.engine import training_utils
2828
from keras.saving import serialization_lib
29+
from keras.saving.legacy import serialization as legacy_serialization
2930
from keras.saving.legacy.saved_model import model_serialization
3031
from keras.utils import generic_utils
3132
from keras.utils import layer_utils
@@ -441,14 +442,15 @@ def compute_mask(self, inputs, mask):
441442

442443
def get_config(self):
443444
layer_configs = []
445+
serialize_obj_fn = serialization_lib.serialize_keras_object
446+
if getattr(self, "use_legacy_config", None):
447+
serialize_obj_fn = legacy_serialization.serialize_keras_object
444448
for layer in super().layers:
445449
# `super().layers` include the InputLayer if available (it is
446450
# filtered out of `self.layers`). Note that
447451
# `self._self_tracked_trackables` is managed by the tracking
448452
# infrastructure and should not be used.
449-
layer_configs.append(
450-
serialization_lib.serialize_keras_object(layer)
451-
)
453+
layer_configs.append(serialize_obj_fn(layer))
452454
config = training.Model.get_config(self)
453455
config["name"] = self.name
454456
config["layers"] = copy.deepcopy(layer_configs)

keras/saving/legacy/hdf5_format.py

+3
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
8181
"import h5py."
8282
)
8383

84+
# Ensures that all models saved in HDF5 format follow the old serialization
85+
model.use_legacy_config = True
86+
8487
# TODO(psv) Add warning when we save models that contain non-serializable
8588
# entities like metrics added using `add_metric` and losses added using
8689
# `add_loss.`

keras/saving/legacy/save_test.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,46 @@ def c(self):
11341134
)
11351135
self.assertIsInstance(reloaded_model, new_cls)
11361136

1137+
@test_combinations.generate(test_combinations.combine(mode=["eager"]))
1138+
def test_custom_sequential_registered_no_scope(self):
1139+
@object_registration.register_keras_serializable(package="my_package")
1140+
class MyDense(keras.layers.Dense):
1141+
def __init__(self, units, **kwargs):
1142+
super().__init__(units, **kwargs)
1143+
1144+
input_shape = [1]
1145+
inputs = keras.Input(shape=input_shape)
1146+
custom_layer = MyDense(1)
1147+
saved_model_dir = self._save_model_dir()
1148+
save_format = test_utils.get_save_format()
1149+
1150+
model = keras.Sequential(layers=[inputs, custom_layer])
1151+
model.save(saved_model_dir, save_format=save_format)
1152+
loaded_model = keras.models.load_model(saved_model_dir)
1153+
1154+
x = tf.constant([5])
1155+
self.assertAllEqual(model(x), loaded_model(x))
1156+
1157+
@test_combinations.generate(test_combinations.combine(mode=["eager"]))
1158+
def test_custom_functional_registered_no_scope(self):
1159+
@object_registration.register_keras_serializable(package="my_package")
1160+
class MyDense(keras.layers.Dense):
1161+
def __init__(self, units, **kwargs):
1162+
super().__init__(units, **kwargs)
1163+
1164+
saved_model_dir = self._save_model_dir()
1165+
save_format = test_utils.get_save_format()
1166+
input_shape = [1]
1167+
inputs = keras.Input(shape=input_shape)
1168+
outputs = MyDense(1)(inputs)
1169+
model = keras.Model(inputs, outputs)
1170+
1171+
model.save(saved_model_dir, save_format=save_format)
1172+
loaded_model = keras.models.load_model(saved_model_dir)
1173+
1174+
x = tf.constant([5])
1175+
self.assertAllEqual(model(x), loaded_model(x))
1176+
11371177
@test_combinations.generate(test_combinations.combine(mode=["eager"]))
11381178
def test_shared_objects(self):
11391179
class OuterLayer(keras.layers.Layer):
@@ -1222,7 +1262,6 @@ def _get_all_keys_recursive(dict_or_iterable):
12221262
with object_registration.CustomObjectScope(
12231263
{"OuterLayer": OuterLayer, "InnerLayer": InnerLayer}
12241264
):
1225-
12261265
# Test saving and loading to disk
12271266
save_format = test_utils.get_save_format()
12281267
saved_model_dir = self._save_model_dir()

0 commit comments

Comments
 (0)