diff --git a/keras/layers/layer.py b/keras/layers/layer.py index a19a29f7f22..db91e349b98 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -635,15 +635,20 @@ def non_trainable_weights(self): return self.weights return [v for v in self.weights if not v.trainable] + @property + def metrics(self): + """List of all metrics.""" + metrics = list(self._metrics) + for layer in self._layers: + metrics.extend(layer.metrics) + return metrics + @property def metrics_variables(self): """List of all metric variables.""" vars = [] - for metric in self._metrics: + for metric in self.metrics: vars.extend(metric.variables) - for layer in self._layers: - for metric in layer._metrics: - vars.extend(metric.variables) return vars def get_weights(self): diff --git a/keras/layers/layer_test.py b/keras/layers/layer_test.py index ab93d808c72..0e8ca4548df 100644 --- a/keras/layers/layer_test.py +++ b/keras/layers/layer_test.py @@ -176,7 +176,7 @@ def call(self, x): self.assertAllClose(layer.variables[1], [10, 1]) def test_layer_tracking(self): - class NestedLayer(layers.Layer): + class LayerWithDenseLayers(layers.Layer): def __init__(self, units): super().__init__() self.dense1 = layers.Dense(units) @@ -185,6 +185,7 @@ def __init__(self, units): } self.layer_list = [layers.Dense(units)] self.units = units + self.seed_generator = backend.random.SeedGenerator(seed=1) def build(self, input_shape): self.layer_list.append(layers.Dense(self.units)) @@ -196,24 +197,31 @@ def call(self, x): x = self.layer_list[1](x) return x - class DoubleNestedLayer(layers.Layer): - def __init__(self, units): + class ParentLayer(layers.Layer): + def __init__(self, inner_layer): super().__init__() - self.inner_layer = NestedLayer(units) + self.inner_layer = inner_layer def call(self, x): return self.inner_layer(x) - layer = NestedLayer(3) + layer = LayerWithDenseLayers(3) layer.build((1, 3)) self.assertLen(layer._layers, 4) layer(np.zeros((1, 3))) + self.assertLen(layer.variables, 9) + self.assertLen(layer.weights, 8) + + layer = ParentLayer(LayerWithDenseLayers(3)) + self.assertLen(layer._layers, 1) + layer(np.zeros((1, 3))) + self.assertLen(layer.variables, 9) self.assertLen(layer.weights, 8) - layer = DoubleNestedLayer(3) + layer = ParentLayer(ParentLayer(LayerWithDenseLayers(3))) self.assertLen(layer._layers, 1) layer(np.zeros((1, 3))) - self.assertLen(layer.inner_layer.weights, 8) + self.assertLen(layer.variables, 9) self.assertLen(layer.weights, 8) def test_metric_tracking(self): @@ -229,32 +237,42 @@ def build(self, input_shape): def call(self, x): return self.dense(x) - class NestedLayerWithMetric(layers.Layer): - def __init__(self, units): + class ParentLayerWithMetric(layers.Layer): + def __init__(self, inner_layer): super().__init__() - self.layer_with_metric = LayerWithMetric(units) + self.inner_layer = inner_layer self.metric = metrics.MeanSquaredError(name="my_metric") def build(self, input_shape): - self.layer_with_metric.build(input_shape) + self.inner_layer.build(input_shape) def call(self, x): - return self.layer_with_metric(x) + return self.inner_layer(x) layer = LayerWithMetric(3) layer.build((1, 3)) + self.assertLen(layer.metrics, 1) self.assertLen(layer.metrics_variables, 2) self.assertLen(layer.trainable_variables, 2) self.assertLen(layer.non_trainable_variables, 0) - layer = NestedLayerWithMetric(3) + layer = ParentLayerWithMetric(LayerWithMetric(3)) layer.build((1, 3)) + self.assertLen(layer.metrics, 2) self.assertLen(layer.metrics_variables, 4) self.assertLen(layer.trainable_variables, 2) self.assertLen(layer.non_trainable_variables, 0) + layer = ParentLayerWithMetric(ParentLayerWithMetric(LayerWithMetric(3))) + layer.build((1, 3)) + + self.assertLen(layer.metrics, 3) + self.assertLen(layer.metrics_variables, 6) + self.assertLen(layer.trainable_variables, 2) + self.assertLen(layer.non_trainable_variables, 0) + def test_build_on_call(self): class LayerWithUnbuiltState(layers.Layer): def __init__(self, units): diff --git a/keras/trainers/trainer.py b/keras/trainers/trainer.py index 32d818e0b64..0ee156f9c5e 100644 --- a/keras/trainers/trainer.py +++ b/keras/trainers/trainer.py @@ -242,7 +242,7 @@ def run_eagerly(self, value): @property def metrics(self): metrics = [self._loss_tracker] if self.compiled else [] - metrics.extend(self._metrics[:]) + metrics.extend(super().metrics) if self.compiled and self._compile_metrics is not None: metrics += [self._compile_metrics] return metrics @@ -251,13 +251,6 @@ def metrics(self): def metrics_names(self): return [m.name for m in self.metrics] - @property - def metrics_variables(self): - vars = [] - for metric in self.metrics: - vars.extend(metric.variables) - return vars - def reset_metrics(self): for m in self.metrics: m.reset_state()