Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models and layers now return owned metrics recursively. #19522

Merged
merged 1 commit
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions keras/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,15 +635,20 @@ def non_trainable_weights(self):
return self.weights
return [v for v in self.weights if not v.trainable]

@property
def metrics(self):
    """All metrics owned by this layer and, recursively, by its sublayers."""
    collected = list(self._metrics)
    collected.extend(
        metric for sublayer in self._layers for metric in sublayer.metrics
    )
    return collected

@property
def metrics_variables(self):
    """List of all metric variables, including those of nested sublayers.

    `self.metrics` is already recursive (it folds in sublayer metrics),
    so no explicit walk over `self._layers` is needed here.
    """
    # Diff residue removed: the old `self._metrics` loop and the stale
    # per-sublayer recursion were interleaved with the new code.
    variables = []  # renamed from `vars` to avoid shadowing the builtin
    for metric in self.metrics:
        variables.extend(metric.variables)
    return variables

def get_weights(self):
Expand Down
44 changes: 31 additions & 13 deletions keras/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def call(self, x):
self.assertAllClose(layer.variables[1], [10, 1])

def test_layer_tracking(self):
class NestedLayer(layers.Layer):
class LayerWithDenseLayers(layers.Layer):
def __init__(self, units):
super().__init__()
self.dense1 = layers.Dense(units)
Expand All @@ -185,6 +185,7 @@ def __init__(self, units):
}
self.layer_list = [layers.Dense(units)]
self.units = units
self.seed_generator = backend.random.SeedGenerator(seed=1)

def build(self, input_shape):
self.layer_list.append(layers.Dense(self.units))
Expand All @@ -196,24 +197,31 @@ def call(self, x):
x = self.layer_list[1](x)
return x

class DoubleNestedLayer(layers.Layer):
def __init__(self, units):
class ParentLayer(layers.Layer):
def __init__(self, inner_layer):
super().__init__()
self.inner_layer = NestedLayer(units)
self.inner_layer = inner_layer

def call(self, x):
return self.inner_layer(x)

layer = NestedLayer(3)
layer = LayerWithDenseLayers(3)
layer.build((1, 3))
self.assertLen(layer._layers, 4)
layer(np.zeros((1, 3)))
self.assertLen(layer.variables, 9)
self.assertLen(layer.weights, 8)

layer = ParentLayer(LayerWithDenseLayers(3))
self.assertLen(layer._layers, 1)
layer(np.zeros((1, 3)))
self.assertLen(layer.variables, 9)
self.assertLen(layer.weights, 8)

layer = DoubleNestedLayer(3)
layer = ParentLayer(ParentLayer(LayerWithDenseLayers(3)))
self.assertLen(layer._layers, 1)
layer(np.zeros((1, 3)))
self.assertLen(layer.inner_layer.weights, 8)
self.assertLen(layer.variables, 9)
self.assertLen(layer.weights, 8)

def test_metric_tracking(self):
Expand All @@ -229,32 +237,42 @@ def build(self, input_shape):
def call(self, x):
return self.dense(x)

class NestedLayerWithMetric(layers.Layer):
def __init__(self, units):
class ParentLayerWithMetric(layers.Layer):
def __init__(self, inner_layer):
super().__init__()
self.layer_with_metric = LayerWithMetric(units)
self.inner_layer = inner_layer
self.metric = metrics.MeanSquaredError(name="my_metric")

def build(self, input_shape):
self.layer_with_metric.build(input_shape)
self.inner_layer.build(input_shape)

def call(self, x):
return self.layer_with_metric(x)
return self.inner_layer(x)

layer = LayerWithMetric(3)
layer.build((1, 3))

self.assertLen(layer.metrics, 1)
self.assertLen(layer.metrics_variables, 2)
self.assertLen(layer.trainable_variables, 2)
self.assertLen(layer.non_trainable_variables, 0)

layer = NestedLayerWithMetric(3)
layer = ParentLayerWithMetric(LayerWithMetric(3))
layer.build((1, 3))

self.assertLen(layer.metrics, 2)
self.assertLen(layer.metrics_variables, 4)
self.assertLen(layer.trainable_variables, 2)
self.assertLen(layer.non_trainable_variables, 0)

layer = ParentLayerWithMetric(ParentLayerWithMetric(LayerWithMetric(3)))
layer.build((1, 3))

self.assertLen(layer.metrics, 3)
self.assertLen(layer.metrics_variables, 6)
self.assertLen(layer.trainable_variables, 2)
self.assertLen(layer.non_trainable_variables, 0)

def test_build_on_call(self):
class LayerWithUnbuiltState(layers.Layer):
def __init__(self, units):
Expand Down
9 changes: 1 addition & 8 deletions keras/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def run_eagerly(self, value):
@property
def metrics(self):
    """Loss tracker, layer-owned metrics (recursive), then compile metrics.

    Diff residue removed: the pre-PR line `metrics.extend(self._metrics[:])`
    was interleaved with its replacement; only the `super().metrics` call
    (which is recursive over sublayers) is kept.
    """
    metrics = [self._loss_tracker] if self.compiled else []
    metrics.extend(super().metrics)
    if self.compiled and self._compile_metrics is not None:
        metrics += [self._compile_metrics]
    return metrics
Expand All @@ -251,13 +251,6 @@ def metrics(self):
def metrics_names(self):
    """Names of all metrics, in the same order as `self.metrics`."""
    names = []
    for metric in self.metrics:
        names.append(metric.name)
    return names

@property
def metrics_variables(self):
    """Variables of every metric reported by `self.metrics`, flattened."""
    collected = []
    for metric in self.metrics:
        collected.extend(metric.variables)
    return collected

def reset_metrics(self):
    """Reset the state of every metric tracked by `self.metrics`."""
    for metric in self.metrics:
        metric.reset_state()
Expand Down