Skip to content

Commit

Permalink
Fix weird broken test that never ran (#827)
Browse files Browse the repository at this point in the history
Just stumbled across this, we had a test that didn't have test in the
name, was never run, and was very broken if you tried to run it.
  • Loading branch information
mattdangerw authored Aug 31, 2023
1 parent 5caec45 commit 7a096ef
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions keras_core/saving/serialization_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,10 @@ def get_config(self):
}


class WrapperLayer(keras_core.layers.Layer):
def __init__(self, layer, **kwargs):
super().__init__(**kwargs)
self.layer = layer

class WrapperLayer(keras_core.layers.Wrapper):
def call(self, x):
return self.layer(x)

def get_config(self):
config = super().get_config()
return {"layer": self.layer, **config}


class SerializationLibTest(testing.TestCase):
def roundtrip(self, obj, custom_objects=None, safe_mode=True):
Expand Down Expand Up @@ -210,20 +202,22 @@ def test_dict_inputs_outputs(self):
self.assertAllClose(original_output["foo"], restored_output["foo"])
self.assertAllClose(original_output["bar"], restored_output["bar"])

def shared_inner_layer(self):
input_1 = keras_core.Input((2,))
input_2 = keras_core.Input((2,))
shared_layer = keras_core.layers.Dense(1)
output_1 = shared_layer(input_1)
wrapper_layer = WrapperLayer(shared_layer)
output_2 = wrapper_layer(input_2)
model = keras_core.Model([input_1, input_2], [output_1, output_2])
_, new_model, _ = self.roundtrip(
model, custom_objects={"WrapperLayer": WrapperLayer}
)
@pytest.mark.requires_trainable_backend
def test_shared_inner_layer(self):
with serialization_lib.ObjectSharingScope():
input_1 = keras_core.Input((2,))
input_2 = keras_core.Input((2,))
shared_layer = keras_core.layers.Dense(1)
output_1 = shared_layer(input_1)
wrapper_layer = WrapperLayer(shared_layer)
output_2 = wrapper_layer(input_2)
model = keras_core.Model([input_1, input_2], [output_1, output_2])
_, new_model, _ = self.roundtrip(
model, custom_objects={"WrapperLayer": WrapperLayer}
)

self.assertIs(model.layers[2], model.layers[3].layer)
self.assertIs(new_model.layers[2], new_model.layers[3].layer)
self.assertIs(model.layers[2], model.layers[3].layer)
self.assertIs(new_model.layers[2], new_model.layers[3].layer)

@pytest.mark.requires_trainable_backend
def test_functional_subclass(self):
Expand Down

0 comments on commit 7a096ef

Please sign in to comment.