diff --git a/prometheus_client/metrics_core.py b/prometheus_client/metrics_core.py index 4fd0fcd7..4ab2af16 100644 --- a/prometheus_client/metrics_core.py +++ b/prometheus_client/metrics_core.py @@ -58,6 +58,15 @@ def __repr__(self): self.samples, ) + def _restricted_metric(self, names): + """Build a snapshot of a metric with samples restricted to a given set of names.""" + samples = [s for s in self.samples if s[0] in names] + if samples: + m = Metric(self.name, self.documentation, self.type) + m.samples = samples + return m + return None + class UnknownMetricFamily(Metric): """A single unknown metric and its samples. diff --git a/prometheus_client/registry.py b/prometheus_client/registry.py index fff1f98c..0eb003bb 100644 --- a/prometheus_client/registry.py +++ b/prometheus_client/registry.py @@ -94,28 +94,7 @@ def restricted_registry(self, names): Experimental.""" names = set(names) - collectors = set() - metrics = [] - with self._lock: - if 'target_info' in names and self._target_info: - metrics.append(self._target_info_metric()) - names.remove('target_info') - for name in names: - if name in self._names_to_collectors: - collectors.add(self._names_to_collectors[name]) - for collector in collectors: - for metric in collector.collect(): - samples = [s for s in metric.samples if s[0] in names] - if samples: - m = Metric(metric.name, metric.documentation, metric.type) - m.samples = samples - metrics.append(m) - - class RestrictedRegistry(object): - def collect(self): - return metrics - - return RestrictedRegistry() + return RestrictedRegistry(names, self) def set_target_info(self, labels): with self._lock: @@ -150,4 +129,16 @@ def get_sample_value(self, name, labels=None): return None +class RestrictedRegistry(object): + def __init__(self, names, registry): + self._name_set = set(names) + self._registry = registry + + def collect(self): + for metric in self._registry.collect(): + m = metric._restricted_metric(self._name_set) + if m: + yield m + + REGISTRY = CollectorRegistry(auto_describe=True) diff --git a/tests/test_core.py b/tests/test_core.py index dbb27033..d36aef0a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -796,7 +796,19 @@ def test_restricted_registry(self): m = Metric('s', 'help', 'summary') m.samples = [Sample('s_sum', {}, 7)] - self.assertEqual([m], registry.restricted_registry(['s_sum']).collect()) + self.assertEqual([m], list(registry.restricted_registry(['s_sum']).collect())) + + def test_restricted_registry_adds_new_metrics(self): + registry = CollectorRegistry() + Counter('c_total', 'help', registry=registry) + + restricted_registry = registry.restricted_registry(['s_sum']) + + Summary('s', 'help', registry=registry).observe(7) + m = Metric('s', 'help', 'summary') + m.samples = [Sample('s_sum', {}, 7)] + + self.assertEqual([m], list(restricted_registry.collect())) def test_target_info_injected(self): registry = CollectorRegistry(target_info={'foo': 'bar'}) @@ -820,11 +832,11 @@ def test_target_info_restricted_registry(self): m = Metric('s', 'help', 'summary') m.samples = [Sample('s_sum', {}, 7)] - self.assertEqual([m], registry.restricted_registry(['s_sum']).collect()) + self.assertEqual([m], list(registry.restricted_registry(['s_sum']).collect())) m = Metric('target', 'Target metadata', 'info') m.samples = [Sample('target_info', {'foo': 'bar'}, 1)] - self.assertEqual([m], registry.restricted_registry(['target_info']).collect()) + self.assertEqual([m], list(registry.restricted_registry(['target_info']).collect())) if __name__ == '__main__':