Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Primitives support the dynamic circuits with control flow #9231

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions qiskit/primitives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""
from __future__ import annotations

from collections.abc import Iterable

import numpy as np

from qiskit.circuit import Instruction, ParameterExpression, QuantumCircuit
Expand Down Expand Up @@ -123,6 +125,16 @@ def _bits_key(bits: tuple[Bit, ...], circuit: QuantumCircuit) -> tuple:
)


def _format_params(param):
if isinstance(param, np.ndarray):
return param.data.tobytes()
elif isinstance(param, QuantumCircuit):
return _circuit_key(param)
elif isinstance(param, Iterable):
return tuple(param)
return param


def _circuit_key(circuit: QuantumCircuit, functional: bool = True) -> tuple:
"""Private key function for QuantumCircuit.

Expand All @@ -145,10 +157,7 @@ def _circuit_key(circuit: QuantumCircuit, functional: bool = True) -> tuple:
_bits_key(data.qubits, circuit), # qubits
_bits_key(data.clbits, circuit), # clbits
data.operation.name, # operation.name
tuple(
param.data.tobytes() if isinstance(param, np.ndarray) else param
for param in data.operation.params
), # operation.params
tuple(_format_params(param) for param in data.operation.params), # operation.params
)
for data in circuit.data
),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Primitives may now support dynamic circuits with control flow, if the particular
provider's implementation can support them. Previously, the
:class:`~BaseSampler` and :class:`~BaseEstimator` base classes could not correctly
normalize such circuits. This change does not automatically make all
primitives support dynamic circuits, but it does make it possible for them
to be supported by downstream providers.
23 changes: 22 additions & 1 deletion test/python/primitives/test_backend_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
# (C) Copyright IBM 2022, 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -24,6 +24,7 @@
from qiskit.providers import JobStatus, JobV1
from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2
from qiskit.test import QiskitTestCase
from qiskit.utils import optionals

BACKENDS = [FakeNairobi(), FakeNairobiV2()]

Expand Down Expand Up @@ -319,6 +320,26 @@ def test_primitive_job_size_limit_backend_v1(self):
self.assertDictAlmostEqual(result.quasi_dists[0], {0: 1}, 0.1)
self.assertDictAlmostEqual(result.quasi_dists[1], {1: 1}, 0.1)

@unittest.skipUnless(optionals.HAS_AER, "qiskit-aer is required to run this test")
def test_circuit_with_dynamic_circuit(self):
"""Test BackendSampler with QuantumCircuit with a dynamic circuit"""
from qiskit_aer import Aer

qc = QuantumCircuit(2, 1)

with qc.for_loop(range(5)):
qc.h(0)
qc.cx(0, 1)
qc.measure(0, 0)
qc.break_loop().c_if(0, True)

backend = Aer.get_backend("aer_simulator")
backend.set_options(seed_simulator=15)
sampler = BackendSampler(backend, skip_transpilation=True)
sampler.set_transpile_options(seed_transpiler=15)
result = sampler.run(qc).result()
self.assertDictAlmostEqual(result.quasi_dists[0], {0: 0.5029296875, 1: 0.4970703125})

def test_sequential_run(self):
"""Test sequential run."""
qc = QuantumCircuit(1)
Expand Down
55 changes: 52 additions & 3 deletions test/python/primitives/test_primitive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
# (C) Copyright IBM 2022, 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -12,12 +12,16 @@

"""Tests for BasePrimitive."""

from ddt import ddt, data, unpack
import json

from numpy import array, int32, int64, float32, float64
from ddt import data, ddt, unpack
from numpy import array, float32, float64, int32, int64

from qiskit import QuantumCircuit, pulse, transpile
from qiskit.circuit.random import random_circuit
from qiskit.primitives.base.base_primitive import BasePrimitive
from qiskit.primitives.utils import _circuit_key
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a big fan of these tests of internal features, because generally they make it harder for the internal implementation to evolve without changing the tests. Ideally we should only be testing the outward public behaviour, and not the particulars of how that's achieved. If the tests are too tightly coupled, then any change to the internal implementation has to change the tests, and that makes it much easier for bugs and regressions to sneak in; we have to verify that the new tests are equivalent, and cover the same ground.

That said, this is a single function, and if you think it's much better to test this way, I won't block this PR on it.

I'd potentially attempt to rewrite the tests by using a custom subclass of the primitives for testing that has its _run (or whatever it's called) method defined to make assertions about the normalised circuits it gets back from the base class's handling. That's testing public parts of the subclassing API, rather than this internal detail of how that's done.

edit: I also see now that these tests are just moved from one file to another, rather than actually written new. I'm still not a fan, but it does move this change out-of-scope of this PR.

from qiskit.providers.fake_provider import FakeAlmaden
from qiskit.test import QiskitTestCase


Expand Down Expand Up @@ -110,3 +114,48 @@ def test_value_error(self):
"""Test value error if no parameter_values or default are provided."""
with self.assertRaises(ValueError):
BasePrimitive._validate_parameter_values(None)


class TestCircuitKey(QiskitTestCase):
"""Tests for _circuit_key function"""

def test_different_circuits(self):
"""Test collision of quantum circuits."""

with self.subTest("Ry circuit"):

def test_func(n):
qc = QuantumCircuit(1, 1, name="foo")
qc.ry(n, 0)
return qc

keys = [_circuit_key(test_func(i)) for i in range(5)]
self.assertEqual(len(keys), len(set(keys)))

with self.subTest("pulse circuit"):

def test_with_scheduling(n):
custom_gate = pulse.Schedule(name="custom_x_gate")
custom_gate.insert(
0, pulse.Play(pulse.Constant(160 * n, 0.1), pulse.DriveChannel(0)), inplace=True
)
qc = QuantumCircuit(1)
qc.x(0)
qc.add_calibration("x", qubits=(0,), schedule=custom_gate)
return transpile(qc, FakeAlmaden(), scheduling_method="alap")

keys = [_circuit_key(test_with_scheduling(i)) for i in range(1, 5)]
self.assertEqual(len(keys), len(set(keys)))

def test_circuit_key_controlflow(self):
"""Test for a circuit with control flow."""
qc = QuantumCircuit(2, 1)

with qc.for_loop(range(5)):
qc.h(0)
qc.cx(0, 1)
qc.measure(0, 0)
qc.break_loop().c_if(0, True)

self.assertIsInstance(hash(_circuit_key(qc)), int)
self.assertIsInstance(json.dumps(_circuit_key(qc)), str)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This json line doesn't look like it belongs in Terra - JSON serialisation of the internal object _circuit_key shouldn't be something Terra needs, and no package downstream of us should be relying on a private function having any particular behaviour.

Using a dummy primitive whose _run method (or whatever) just makes assertions / leaks the received circuit back to the caller somehow could be a cleaner way of making a public-API test.

34 changes: 2 additions & 32 deletions test/python/primitives/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
# (C) Copyright IBM 2022, 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -18,15 +18,13 @@
import numpy as np
from ddt import ddt

from qiskit import QuantumCircuit, pulse, transpile
from qiskit import QuantumCircuit
from qiskit.circuit import Parameter
from qiskit.circuit.library import RealAmplitudes
from qiskit.exceptions import QiskitError
from qiskit.extensions.unitary import UnitaryGate
from qiskit.primitives import Sampler, SamplerResult
from qiskit.primitives.utils import _circuit_key
from qiskit.providers import JobStatus, JobV1
from qiskit.providers.fake_provider import FakeAlmaden
from qiskit.test import QiskitTestCase


Expand Down Expand Up @@ -743,34 +741,6 @@ def test_options(self):
self._compare_probs(result.quasi_dists, target)
self.assertEqual(result.quasi_dists[0].shots, 1024)

def test_different_circuits(self):
"""Test collision of quantum circuits."""

with self.subTest("Ry circuit"):

def test_func(n):
qc = QuantumCircuit(1, 1, name="foo")
qc.ry(n, 0)
return qc

keys = [_circuit_key(test_func(i)) for i in range(5)]
self.assertEqual(len(keys), len(set(keys)))

with self.subTest("pulse circuit"):

def test_with_scheduling(n):
custom_gate = pulse.Schedule(name="custom_x_gate")
custom_gate.insert(
0, pulse.Play(pulse.Constant(160 * n, 0.1), pulse.DriveChannel(0)), inplace=True
)
qc = QuantumCircuit(1)
qc.x(0)
qc.add_calibration("x", qubits=(0,), schedule=custom_gate)
return transpile(qc, FakeAlmaden(), scheduling_method="alap")

keys = [_circuit_key(test_with_scheduling(i)) for i in range(1, 5)]
self.assertEqual(len(keys), len(set(keys)))

def test_circuit_with_unitary(self):
"""Test for circuit with unitary gate."""
gate = UnitaryGate(np.eye(2))
Expand Down