Skip to content

Commit

Permalink
Expose seed in Estimator and StatevectorEstimator (#12862)
Browse files Browse the repository at this point in the history
* Fix Estimator and StatevectorEstimator with reset

* reno

* use rng instead of seed

* update reno and tests

* simplify tests

* apply review comments

* Apply suggestions from code review

Co-authored-by: Julien Gacon <gaconju@gmail.com>

---------

Co-authored-by: Julien Gacon <gaconju@gmail.com>
Co-authored-by: Julien Gacon <jules.gacon@googlemail.com>
  • Loading branch information
3 people committed Aug 8, 2024
1 parent 154601b commit e7ee189
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 11 deletions.
17 changes: 12 additions & 5 deletions qiskit/primitives/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from qiskit.circuit import QuantumCircuit
from qiskit.exceptions import QiskitError
from qiskit.quantum_info import Statevector
from qiskit.quantum_info.operators.base_operator import BaseOperator
from qiskit.utils.deprecation import deprecate_func

Expand All @@ -31,7 +30,7 @@
from .utils import (
_circuit_key,
_observable_key,
bound_circuit_to_instruction,
_statevector_from_circuit,
init_observable,
)

Expand All @@ -43,13 +42,21 @@ class Estimator(BaseEstimator[PrimitiveJob[EstimatorResult]]):
:Run Options:
- **shots** (None or int) --
The number of shots. If None, it calculates the exact expectation
values. Otherwise, it samples from normal distributions with standard errors as standard
The number of shots. If None, it calculates the expectation values
with full state vector simulation.
Otherwise, it samples from normal distributions with standard errors as standard
deviations using normal distribution approximation.
- **seed** (np.random.Generator or int) --
Set a fixed seed or generator for the normal distribution. If shots is None,
this option is ignored.
.. note::
The result of this class is exact if the circuit contains only unitary operations.
On the other hand, the result could be stochastic if the circuit contains a non-unitary
operation such as a reset for a some subsystems.
The stochastic result can be made reproducible by setting ``seed``, e.g.,
``Estimator(options={"seed":123})``.
"""

@deprecate_func(
Expand Down Expand Up @@ -112,7 +119,7 @@ def _call(
f"The number of qubits of a circuit ({circ.num_qubits}) does not match "
f"the number of qubits of a observable ({obs.num_qubits})."
)
final_state = Statevector(bound_circuit_to_instruction(circ))
final_state = _statevector_from_circuit(circ, rng)
expectation_value = final_state.expectation_value(obs)
if shots is None:
expectation_values.append(expectation_value)
Expand Down
13 changes: 10 additions & 3 deletions qiskit/primitives/statevector_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

import numpy as np

from qiskit.quantum_info import SparsePauliOp, Statevector
from qiskit.quantum_info import SparsePauliOp

from .base import BaseEstimatorV2
from .containers import DataBin, EstimatorPubLike, PrimitiveResult, PubResult
from .containers.estimator_pub import EstimatorPub
from .primitive_job import PrimitiveJob
from .utils import bound_circuit_to_instruction
from .utils import _statevector_from_circuit


class StatevectorEstimator(BaseEstimatorV2):
Expand All @@ -41,6 +41,13 @@ class StatevectorEstimator(BaseEstimatorV2):
called an estimator primitive unified bloc (PUB), produces its own array-based result. The
:meth:`~.EstimatorV2.run` method can be given a sequence of pubs to run in one call.
.. note::
The result of this class is exact if the circuit contains only unitary operations.
On the other hand, the result could be stochastic if the circuit contains a non-unitary
operation such as a reset for a some subsystems.
The stochastic result can be made reproducible by setting ``seed``, e.g.,
``StatevectorEstimator(seed=123)``.
.. plot::
:include-source:
Expand Down Expand Up @@ -151,7 +158,7 @@ def _run_pub(self, pub: EstimatorPub) -> PubResult:
for index in np.ndindex(*bc_circuits.shape):
bound_circuit = bc_circuits[index]
observable = bc_obs[index]
final_state = Statevector(bound_circuit_to_instruction(bound_circuit))
final_state = _statevector_from_circuit(bound_circuit, rng)
paulis, coeffs = zip(*observable.items())
obs = SparsePauliOp(paulis, coeffs) # TODO: support non Pauli operators
expectation_value = np.real_if_close(final_state.expectation_value(obs))
Expand Down
20 changes: 20 additions & 0 deletions qiskit/primitives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,23 @@ def bound_circuit_to_instruction(circuit: QuantumCircuit) -> Instruction:
)
inst.definition = circuit
return inst


def _statevector_from_circuit(
circuit: QuantumCircuit, rng: np.random.Generator | None
) -> Statevector:
"""Generate a statevector from a circuit
If the input circuit includes any resets for a some subsystem,
:meth:`.Statevector.reset` behaves in a stochastic way in :meth:`.Statevector.evolve`.
This function sets a random number generator to be reproducible.
See :meth:`.Statevector.reset` for details.
Args:
circuit: The quantum circuit.
seed: The random number generator or None.
"""
sv = Statevector.from_int(0, 2**circuit.num_qubits)
sv.seed(rng)
return sv.evolve(bound_circuit_to_instruction(circuit))
19 changes: 19 additions & 0 deletions releasenotes/notes/fix-estimator-reset-9e7539776df4cac4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
---
features_primitives:
- |
:class:`.Estimator` and :class:`.StatevectorEstimator` return
expectation values in a stochastic way if the input circuit includes
a reset for a some subsystems.
The result was not reproducible, but it is now reproducible
if a random seed is set. For example::
from qiskit.primitives import StatevectorEstimator
estimator = StatevectorEstimator(seed=123)
or::
from qiskit.primitives import Estimator
estimator = Estimator(options={"seed":123})
37 changes: 36 additions & 1 deletion test/python/primitives/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""Tests for Estimator."""

import unittest
from test import QiskitTestCase

import numpy as np
from ddt import data, ddt, unpack
Expand All @@ -24,7 +25,6 @@
from qiskit.primitives.base import validation
from qiskit.primitives.utils import _observable_key
from qiskit.quantum_info import Pauli, SparsePauliOp
from test import QiskitTestCase # pylint: disable=wrong-import-order


class TestEstimator(QiskitTestCase):
Expand Down Expand Up @@ -355,6 +355,41 @@ def get_op(i):
keys = [_observable_key(get_op(i)) for i in range(5)]
self.assertEqual(len(keys), len(set(keys)))

def test_reset(self):
"""Test for circuits with reset."""
qc = QuantumCircuit(2)
qc.h(0)
qc.cx(0, 1)
qc.reset(0)
op = SparsePauliOp("ZI")

seed = 12
n = 1000
with self.subTest("shots=None"):
with self.assertWarns(DeprecationWarning):
estimator = Estimator(options={"seed": seed})
result = estimator.run([qc for _ in range(n)], [op] * n).result()
# expectation values should be stochastic due to reset for subsystems
np.testing.assert_allclose(result.values.mean(), 0, atol=1e-1)

with self.assertWarns(DeprecationWarning):
result2 = estimator.run([qc for _ in range(n)], [op] * n).result()
# expectation values should be reproducible due to seed
np.testing.assert_allclose(result.values, result2.values)

with self.subTest("shots=10000"):
shots = 10000
with self.assertWarns(DeprecationWarning):
estimator = Estimator(options={"seed": seed})
result = estimator.run([qc for _ in range(n)], [op] * n, shots=shots).result()
# expectation values should be stochastic due to reset for subsystems
np.testing.assert_allclose(result.values.mean(), 0, atol=1e-1)

with self.assertWarns(DeprecationWarning):
result2 = estimator.run([qc for _ in range(n)], [op] * n, shots=shots).result()
# expectation values should be reproducible due to seed
np.testing.assert_allclose(result.values, result2.values)


@ddt
class TestObservableValidation(QiskitTestCase):
Expand Down
34 changes: 32 additions & 2 deletions test/python/primitives/test_statevector_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
"""Tests for Estimator."""

import unittest
from test import QiskitTestCase

import numpy as np

from qiskit.circuit import Parameter, QuantumCircuit
from qiskit.circuit.library import RealAmplitudes
from qiskit.primitives import StatevectorEstimator
from qiskit.primitives.containers.bindings_array import BindingsArray
from qiskit.primitives.containers.estimator_pub import EstimatorPub
from qiskit.primitives.containers.observables_array import ObservablesArray
from qiskit.primitives.containers.bindings_array import BindingsArray
from qiskit.quantum_info import SparsePauliOp
from test import QiskitTestCase # pylint: disable=wrong-import-order


class TestStatevectorEstimator(QiskitTestCase):
Expand Down Expand Up @@ -307,6 +307,36 @@ def test_metadata(self):
result[1].metadata, {"target_precision": 0.1, "circuit_metadata": qc2.metadata}
)

def test_reset(self):
"""Test for circuits with reset."""
qc = QuantumCircuit(2)
qc.h(0)
qc.cx(0, 1)
qc.reset(0)
op = SparsePauliOp("ZI")

seed = 12
n = 1000
estimator = StatevectorEstimator(seed=seed)
with self.subTest("precision=0"):
result = estimator.run([(qc, [op] * n)]).result()
# expectation values should be stochastic due to reset for subsystems
np.testing.assert_allclose(result[0].data.evs.mean(), 0, atol=1e-1)

result2 = estimator.run([(qc, [op] * n)]).result()
# expectation values should be reproducible due to seed
np.testing.assert_allclose(result[0].data.evs, result2[0].data.evs)

with self.subTest("precision=0.01"):
precision = 0.01
result = estimator.run([(qc, [op] * n)], precision=precision).result()
# expectation values should be stochastic due to reset for subsystems
np.testing.assert_allclose(result[0].data.evs.mean(), 0, atol=1e-1)

result2 = estimator.run([(qc, [op] * n)], precision=precision).result()
# expectation values should be reproducible due to seed
np.testing.assert_allclose(result[0].data.evs, result2[0].data.evs)


if __name__ == "__main__":
unittest.main()

0 comments on commit e7ee189

Please sign in to comment.