From 05310bd3242ef7a02b63995c518f33fd281bc4b5 Mon Sep 17 00:00:00 2001 From: ElePT <57907331+ElePT@users.noreply.github.com> Date: Fri, 24 Feb 2023 17:23:46 +0100 Subject: [PATCH] Fix bug in backend primitives with `bound_pass_manager` (#9629) * Fix pass, add unit test * Fix black * Fix lint * Fix tests * Add reno * Move assert to within cm * black --------- Co-authored-by: ikkoham Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- qiskit/primitives/backend_estimator.py | 5 ++- qiskit/primitives/backend_sampler.py | 11 ++++-- qiskit/primitives/base/base_estimator.py | 2 +- qiskit/primitives/base/base_sampler.py | 2 +- ...m-backend-primitives-98fd11c5e852501c.yaml | 7 ++++ .../primitives/test_backend_estimator.py | 39 +++++++++++++++++++ .../python/primitives/test_backend_sampler.py | 36 +++++++++++++++++ 7 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/fix-bound-pm-backend-primitives-98fd11c5e852501c.yaml diff --git a/qiskit/primitives/backend_estimator.py b/qiskit/primitives/backend_estimator.py index c2e9a424c6f7..1184d9420948 100644 --- a/qiskit/primitives/backend_estimator.py +++ b/qiskit/primitives/backend_estimator.py @@ -392,7 +392,10 @@ def _bound_pass_manager_run(self, circuits): if self._bound_pass_manager is None: return circuits else: - return self._bound_pass_manager.run(circuits) + output = self._bound_pass_manager.run(circuits) + if not isinstance(output, list): + output = [output] + return output def _paulis2inds(paulis: PauliList) -> list[int]: diff --git a/qiskit/primitives/backend_sampler.py b/qiskit/primitives/backend_sampler.py index 241e79ede911..6059ded4f8c6 100644 --- a/qiskit/primitives/backend_sampler.py +++ b/qiskit/primitives/backend_sampler.py @@ -16,7 +16,7 @@ import math from collections.abc import Sequence -from typing import Any, cast +from typing import Any from qiskit.circuit.quantumcircuit import QuantumCircuit from qiskit.providers.backend import BackendV1, BackendV2 @@ -154,7 +154,6 @@ def _call( for i, value in zip(circuits, parameter_values) ] bound_circuits = self._bound_pass_manager_run(bound_circuits) - # Run result, _metadata = _run_circuits(bound_circuits, self._backend, **run_options) return self._postprocessing(result, bound_circuits) @@ -164,7 +163,7 @@ def _postprocessing(self, result: Result, circuits: list[QuantumCircuit]) -> Sam shots = sum(counts[0].values()) probabilities = [] - metadata: list[dict[str, Any]] = [{}] * len(circuits) + metadata: list[dict[str, Any]] = [{} for _ in range(len(circuits))] for count in counts: prob_dist = {k: v / shots for k, v in count.int_outcomes().items()} probabilities.append( @@ -172,6 +171,7 @@ def _postprocessing(self, result: Result, circuits: list[QuantumCircuit]) -> Sam ) for metadatum in metadata: metadatum["shots"] = shots + return SamplerResult(probabilities, metadata) def _transpile(self): @@ -190,7 +190,10 @@ def _bound_pass_manager_run(self, circuits): if self._bound_pass_manager is None: return circuits else: - return cast("list[QuantumCircuit]", self._bound_pass_manager.run(circuits)) + output = self._bound_pass_manager.run(circuits) + if not isinstance(output, list): + output = [output] + return output def _run( self, diff --git a/qiskit/primitives/base/base_estimator.py b/qiskit/primitives/base/base_estimator.py index 9d3f5212c80e..e5d396884fdb 100644 --- a/qiskit/primitives/base/base_estimator.py +++ b/qiskit/primitives/base/base_estimator.py @@ -96,9 +96,9 @@ from qiskit.quantum_info.operators.base_operator import BaseOperator from qiskit.utils.deprecation import deprecate_arguments, deprecate_function +from ..utils import _circuit_key, _observable_key, init_observable from .base_primitive import BasePrimitive from .estimator_result import EstimatorResult -from ..utils import _circuit_key, _observable_key, init_observable class BaseEstimator(BasePrimitive): diff --git a/qiskit/primitives/base/base_sampler.py b/qiskit/primitives/base/base_sampler.py index 0ef339d58696..28b717e45b1a 100644 --- a/qiskit/primitives/base/base_sampler.py +++ b/qiskit/primitives/base/base_sampler.py @@ -88,9 +88,9 @@ from qiskit.providers import JobV1 as Job from qiskit.utils.deprecation import deprecate_arguments, deprecate_function +from ..utils import _circuit_key from .base_primitive import BasePrimitive from .sampler_result import SamplerResult -from ..utils import _circuit_key class BaseSampler(BasePrimitive): diff --git a/releasenotes/notes/fix-bound-pm-backend-primitives-98fd11c5e852501c.yaml b/releasenotes/notes/fix-bound-pm-backend-primitives-98fd11c5e852501c.yaml new file mode 100644 index 000000000000..4197e41be561 --- /dev/null +++ b/releasenotes/notes/fix-bound-pm-backend-primitives-98fd11c5e852501c.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixed the :class:`.BackendSampler` and :class:`.BackendExtimator` to run successfully + with a custom ``bound_pass_manager``. Previously, the execution for single circuits with + a ``bound_pass_manager`` would raise a ``ValueError`` because a list was not returned + in one of the steps. diff --git a/test/python/primitives/test_backend_estimator.py b/test/python/primitives/test_backend_estimator.py index 2fcdd6212288..d2a9ea84dd70 100644 --- a/test/python/primitives/test_backend_estimator.py +++ b/test/python/primitives/test_backend_estimator.py @@ -12,8 +12,10 @@ """Tests for Estimator.""" +import logging import unittest from test import combine +from test.python.transpiler._dummy_passes import DummyAP import numpy as np from ddt import ddt @@ -25,9 +27,23 @@ from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2 from qiskit.quantum_info import SparsePauliOp from qiskit.test import QiskitTestCase +from qiskit.transpiler import PassManager BACKENDS = [FakeNairobi(), FakeNairobiV2()] +logger = "LocalLogger" + + +class LogPass(DummyAP): + """A dummy analysis pass that logs when executed""" + + def __init__(self, message): + super().__init__() + self.message = message + + def run(self, dag): + logging.getLogger(logger).info(self.message) + @ddt class TestBackendEstimator(QiskitTestCase): @@ -322,6 +338,29 @@ def test_no_max_circuits(self): self.assertEqual(len(result.metadata), k) np.testing.assert_allclose(result.values, target.values, rtol=0.2, atol=0.2) + def test_bound_pass_manager(self): + """Test bound pass manager.""" + + bound_counter = LogPass("bound_pass_manager") + bound_pass = PassManager(bound_counter) + + estimator = BackendEstimator(backend=FakeNairobi(), bound_pass_manager=bound_pass) + + qc = QuantumCircuit(2) + op = SparsePauliOp.from_list([("II", 1)]) + + with self.subTest("Test single circuit"): + with self.assertLogs(logger, level="INFO") as cm: + _ = estimator.run(qc, op).result() + expected = ["INFO:LocalLogger:bound_pass_manager"] + self.assertEqual(cm.output, expected) + + with self.subTest("Test circuit batch"): + with self.assertLogs(logger, level="INFO") as cm: + _ = estimator.run([qc, qc], [op, op]).result() + expected = ["INFO:LocalLogger:bound_pass_manager"] * 2 + self.assertEqual(cm.output, expected) + if __name__ == "__main__": unittest.main() diff --git a/test/python/primitives/test_backend_sampler.py b/test/python/primitives/test_backend_sampler.py index e3816bf6240d..6f5f52b6babb 100644 --- a/test/python/primitives/test_backend_sampler.py +++ b/test/python/primitives/test_backend_sampler.py @@ -12,9 +12,11 @@ """Tests for BackendSampler.""" +import logging import math import unittest from test import combine +from test.python.transpiler._dummy_passes import DummyAP import numpy as np from ddt import ddt @@ -25,10 +27,24 @@ from qiskit.providers import JobStatus, JobV1 from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2 from qiskit.test import QiskitTestCase +from qiskit.transpiler import PassManager from qiskit.utils import optionals BACKENDS = [FakeNairobi(), FakeNairobiV2()] +logger = "LocalLogger" + + +class LogPass(DummyAP): + """A dummy analysis pass that logs when executed""" + + def __init__(self, message): + super().__init__() + self.message = message + + def run(self, dag): + logging.getLogger(logger).info(self.message) + @ddt class TestBackendSampler(QiskitTestCase): @@ -358,6 +374,26 @@ def test_sequential_run(self): self.assertDictAlmostEqual(result3.quasi_dists[0], {0: 1}, 0.1) self.assertDictAlmostEqual(result3.quasi_dists[1], {1: 1}, 0.1) + def test_bound_pass_manager(self): + """Test bound pass manager.""" + + bound_counter = LogPass("bound_pass_manager") + bound_pass = PassManager(bound_counter) + + sampler = BackendSampler(backend=FakeNairobi(), bound_pass_manager=bound_pass) + + with self.subTest("Test single circuit"): + with self.assertLogs(logger, level="INFO") as cm: + _ = sampler.run(self._circuit[0]).result() + expected = ["INFO:LocalLogger:bound_pass_manager"] + self.assertEqual(cm.output, expected) + + with self.subTest("Test circuit batch"): + with self.assertLogs(logger, level="INFO") as cm: + _ = sampler.run([self._circuit[0], self._circuit[0]]).result() + expected = ["INFO:LocalLogger:bound_pass_manager"] * 2 + self.assertEqual(cm.output, expected) + if __name__ == "__main__": unittest.main()