Skip to content

Commit

Permalink
Don't lock functional models (#823)
Browse files Browse the repository at this point in the history
* Don't lock functional models

* Add unit tests
  • Loading branch information
mattdangerw authored Aug 31, 2023
1 parent a59bfd9 commit 13fc468
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
6 changes: 6 additions & 0 deletions keras_core/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
output_layers = [x._keras_history[0] for x in self.outputs]
self.output_names = [x.name for x in output_layers]

def _lock_state(self):
# Unlike other layers, we allow Functional state to be mutable after
# build. E.g. to attach a layer to a model that is not part of the
# functional DAG.
pass

@property
def layers(self):
layers = []
Expand Down
10 changes: 10 additions & 0 deletions keras_core/models/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def test_scalar_input(self):
out_val = model(in_val)
self.assertAllClose(out_val, np.ones((2, 3)))

@pytest.mark.requires_trainable_backend
def test_mutable_state(self):
inputs = Input(shape=(3,), batch_size=2, name="input")
x = layers.Dense(5)(inputs)
outputs = layers.Dense(5)(x)
model = Functional(inputs, outputs)
# Allow attaching state to a model that isn't directly part of the DAG.
# Most useful for functional subclasses.
model.extra_layer = layers.Dense(5)

@pytest.mark.requires_trainable_backend
def test_basic_flow_multi_output(self):
inputs = Input(shape=(3,), batch_size=2, name="input")
Expand Down

0 comments on commit 13fc468

Please sign in to comment.