@@ -1134,6 +1134,46 @@ def c(self):
1134
1134
)
1135
1135
self .assertIsInstance (reloaded_model , new_cls )
1136
1136
1137
+ @test_combinations .generate (test_combinations .combine (mode = ["eager" ]))
1138
+ def test_custom_sequential_registered_no_scope (self ):
1139
+ @object_registration .register_keras_serializable (package = "my_package" )
1140
+ class MyDense (keras .layers .Dense ):
1141
+ def __init__ (self , units , ** kwargs ):
1142
+ super ().__init__ (units , ** kwargs )
1143
+
1144
+ input_shape = [1 ]
1145
+ inputs = keras .Input (shape = input_shape )
1146
+ custom_layer = MyDense (1 )
1147
+ saved_model_dir = self ._save_model_dir ()
1148
+ save_format = test_utils .get_save_format ()
1149
+
1150
+ model = keras .Sequential (layers = [inputs , custom_layer ])
1151
+ model .save (saved_model_dir , save_format = save_format )
1152
+ loaded_model = keras .models .load_model (saved_model_dir )
1153
+
1154
+ x = tf .constant ([5 ])
1155
+ self .assertAllEqual (model (x ), loaded_model (x ))
1156
+
1157
+ @test_combinations .generate (test_combinations .combine (mode = ["eager" ]))
1158
+ def test_custom_functional_registered_no_scope (self ):
1159
+ @object_registration .register_keras_serializable (package = "my_package" )
1160
+ class MyDense (keras .layers .Dense ):
1161
+ def __init__ (self , units , ** kwargs ):
1162
+ super ().__init__ (units , ** kwargs )
1163
+
1164
+ saved_model_dir = self ._save_model_dir ()
1165
+ save_format = test_utils .get_save_format ()
1166
+ input_shape = [1 ]
1167
+ inputs = keras .Input (shape = input_shape )
1168
+ outputs = MyDense (1 )(inputs )
1169
+ model = keras .Model (inputs , outputs )
1170
+
1171
+ model .save (saved_model_dir , save_format = save_format )
1172
+ loaded_model = keras .models .load_model (saved_model_dir )
1173
+
1174
+ x = tf .constant ([5 ])
1175
+ self .assertAllEqual (model (x ), loaded_model (x ))
1176
+
1137
1177
@test_combinations .generate (test_combinations .combine (mode = ["eager" ]))
1138
1178
def test_shared_objects (self ):
1139
1179
class OuterLayer (keras .layers .Layer ):
@@ -1222,7 +1262,6 @@ def _get_all_keys_recursive(dict_or_iterable):
1222
1262
with object_registration .CustomObjectScope (
1223
1263
{"OuterLayer" : OuterLayer , "InnerLayer" : InnerLayer }
1224
1264
):
1225
-
1226
1265
# Test saving and loading to disk
1227
1266
save_format = test_utils .get_save_format ()
1228
1267
saved_model_dir = self ._save_model_dir ()
0 commit comments