diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index cb2ca1fd6f..bb9ff6706c 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -2,6 +2,9 @@ ### New features since last release +* `lightning.qubit` supports mid-circuit measurements. + [(#621)](https://github.com/PennyLaneAI/pennylane-lightning/pull/621) + * Add two new python classes (LightningStateVector and LightningMeasurements) to support `lightning.qubit2`. [(#613)](https://github.com/PennyLaneAI/pennylane-lightning/pull/613) @@ -32,7 +35,7 @@ This release contains contributions from (in alphabetical order): -Ali Asadi, Amintor Dusko, Vincent Michaud-Rioux +Ali Asadi, Amintor Dusko, Thomas Germain, Vincent Michaud-Rioux --- diff --git a/.github/workflows/tests_gpu_kokkos.yml b/.github/workflows/tests_gpu_kokkos.yml index c31488ad40..44894634ce 100644 --- a/.github/workflows/tests_gpu_kokkos.yml +++ b/.github/workflows/tests_gpu_kokkos.yml @@ -325,7 +325,7 @@ jobs: OMP_PROC_BIND: false run: | cd main/ - PL_DEVICE=lightning.qubit python -m pytest tests/ $COVERAGE_FLAGS + PL_DEVICE=lightning.qubit python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS pl-device-test --device lightning.qubit --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append pl-device-test --device lightning.qubit --shots=None --skip-ops $COVERAGE_FLAGS --cov-append PL_DEVICE=lightning.kokkos python -m pytest tests/ $COVERAGE_FLAGS diff --git a/.github/workflows/tests_linux.yml b/.github/workflows/tests_linux.yml index f1641b27d2..570b822dd0 100644 --- a/.github/workflows/tests_linux.yml +++ b/.github/workflows/tests_linux.yml @@ -106,7 +106,7 @@ jobs: matrix: os: [ubuntu-22.04] pl_backend: ["lightning_qubit"] - timeout-minutes: 30 + timeout-minutes: 60 name: Python tests runs-on: ${{ matrix.os }} @@ -184,7 +184,8 @@ jobs: run: | cd main/ DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"` - PL_DEVICE=${DEVICENAME} python -m pytest tests/ $COVERAGE_FLAGS + PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS + OMP_NUM_THREADS=1 PL_DEVICE=${DEVICENAME} python -m pytest -n auto tests/test_native_mcm.py $COVERAGE_FLAGS pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }} @@ -392,7 +393,7 @@ jobs: run: | cd main/ DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"` - PL_DEVICE=${DEVICENAME} python -m pytest tests/ $COVERAGE_FLAGS + PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }} @@ -590,7 +591,7 @@ jobs: run: | cd main/ DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"` - PL_DEVICE=${DEVICENAME} python -m pytest tests/ $COVERAGE_FLAGS + PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }} @@ -609,10 +610,10 @@ jobs: if: ${{ matrix.pl_backend == 'all' }} run: | cd main/ - PL_DEVICE=lightning.qubit python -m pytest tests/ $COVERAGE_FLAGS + PL_DEVICE=lightning.qubit python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS pl-device-test --device lightning.qubit --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append pl-device-test --device lightning.qubit --shots=None --skip-ops $COVERAGE_FLAGS --cov-append - PL_DEVICE=lightning.kokkos python -m pytest tests/ $COVERAGE_FLAGS + PL_DEVICE=lightning.kokkos python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS pl-device-test --device lightning.kokkos --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append pl-device-test --device lightning.kokkos --shots=None --skip-ops $COVERAGE_FLAGS --cov-append mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }} diff --git a/.github/workflows/tests_without_binary.yml b/.github/workflows/tests_without_binary.yml index 680b2b75c3..a7952931ed 100644 --- a/.github/workflows/tests_without_binary.yml +++ b/.github/workflows/tests_without_binary.yml @@ -106,7 +106,7 @@ jobs: run: | cd main/ DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"` - PL_DEVICE=${DEVICENAME} python -m pytest tests/ $COVERAGE_FLAGS + PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py index ea4384e017..f95ce9d020 100644 --- a/pennylane_lightning/core/_version.py +++ b/pennylane_lightning/core/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.36.0-dev9" +__version__ = "0.36.0-dev10" diff --git a/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp b/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp index f67d625ad7..ded6192d17 100644 --- a/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp +++ b/pennylane_lightning/core/src/measurements/MeasurementsBase.hpp @@ -17,6 +17,7 @@ */ #pragma once +#include #include #include @@ -55,6 +56,7 @@ template class MeasurementsBase { #else const StateVectorT &_statevector; #endif + std::mt19937 rng; public: #ifdef _ENABLE_PLGPU @@ -65,6 +67,23 @@ template class MeasurementsBase { : _statevector{statevector} {}; #endif + /** + * @brief Set the seed of the internal random generator + * + * @param seed Seed + */ + void setSeed(const size_t seed) { rng.seed(seed); } + + /** + * @brief Randomly set the seed of the internal random generator + * + * @param seed Seed + */ + void setRandomSeed() { + std::random_device rd; + setSeed(rd()); + } + /** * @brief Calculate the expectation value for a general Observable. * diff --git a/pennylane_lightning/core/src/simulators/lightning_gpu/measurements/MeasurementsGPU.hpp b/pennylane_lightning/core/src/simulators/lightning_gpu/measurements/MeasurementsGPU.hpp index f6b73fca90..fca0175177 100644 --- a/pennylane_lightning/core/src/simulators/lightning_gpu/measurements/MeasurementsGPU.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_gpu/measurements/MeasurementsGPU.hpp @@ -224,10 +224,10 @@ class Measurements final data_type = CUDA_C_32F; } - std::mt19937 gen(std::random_device{}()); + this->setRandomSeed(); std::uniform_real_distribution dis(0.0, 1.0); for (size_t n = 0; n < num_samples; n++) { - rand_nums[n] = dis(gen); + rand_nums[n] = dis(this->rng); } std::vector samples(num_samples * num_qubits, 0); std::unordered_map cache; diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp index afcfb080ec..c7e8002eb6 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp @@ -36,6 +36,7 @@ namespace { using Pennylane::LightningQubit::Util::Threading; using Pennylane::Util::CPUMemoryModel; using Pennylane::Util::exp2; +using Pennylane::Util::squaredNorm; using namespace Pennylane::LightningQubit::Gates; } // namespace /// @endcond @@ -632,5 +633,53 @@ class StateVectorLQubit : public StateVectorBase { applyMatrix(matrix.data(), wires, inverse); } + + /** + * @brief Collapse the state vector as after having measured one of the + * qubits. + * + * The branch parameter imposes the measurement result on the given wire. + * + * @param wire Wire to collapse. + * @param branch Branch 0 or 1. + */ + void collapse(const std::size_t wire, const bool branch) { + auto *arr = this->getData(); + const size_t stride = pow(2, this->num_qubits_ - (1 + wire)); + const size_t vec_size = pow(2, this->num_qubits_); + const auto section_size = vec_size / stride; + const auto half_section_size = section_size / 2; + + // zero half the entries + // the "half" entries depend on the stride + // *_*_*_*_ for stride 1 + // **__**__ for stride 2 + // ****____ for stride 4 + const size_t k = branch ? 0 : 1; + for (size_t idx = 0; idx < half_section_size; idx++) { + const size_t offset = stride * (k + 2 * idx); + for (size_t ids = 0; ids < stride; ids++) { + arr[offset + ids] = {0., 0.}; + } + } + + normalize(); + } + + /** + * @brief Normalize vector (to have norm 1). + */ + void normalize() { + auto *arr = this->getData(); + PrecisionT norm = std::sqrt(squaredNorm(arr, this->getLength())); + + PL_ABORT_IF(norm < std::numeric_limits::epsilon() * 1e2, + "vector has norm close to zero and can't be normalized"); + + std::complex inv_norm = 1. / norm; + for (size_t k = 0; k < this->getLength(); k++) { + arr[k] *= inv_norm; + } + } }; -} // namespace Pennylane::LightningQubit \ No newline at end of file +} // namespace Pennylane::LightningQubit diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp index 74352a08aa..7994987c7f 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp @@ -220,6 +220,10 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) { } }, "Copy StateVector data into a Numpy array.") + .def("collapse", &StateVectorT::collapse, + "Collapse the statevector onto the 0 or 1 branch of a given wire.") + .def("normalize", &StateVectorT::normalize, + "Normalizes the statevector to norm 1.") .def("applyControlledMatrix", &applyControlledMatrix, "Apply controlled operation") .def("kernel_map", &svKernelMap, diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/measurements/MeasurementsLQubit.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/measurements/MeasurementsLQubit.hpp index 744286740f..0e74849964 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/measurements/MeasurementsLQubit.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/measurements/MeasurementsLQubit.hpp @@ -466,11 +466,10 @@ class Measurements final generate_samples_metropolis(const std::string &kernelname, size_t num_burnin, size_t num_samples) { size_t num_qubits = this->_statevector.getNumQubits(); - std::random_device rd; - std::mt19937 gen(rd()); std::uniform_real_distribution distrib(0.0, 1.0); std::vector samples(num_samples * num_qubits, 0); std::unordered_map cache; + this->setRandomSeed(); TransitionKernelType transition_kernel = TransitionKernelType::Local; if (kernelname == "NonZeroRandom") { @@ -484,13 +483,14 @@ class Measurements final // Burn In for (size_t i = 0; i < num_burnin; i++) { - idx = metropolis_step(this->_statevector, tk, gen, distrib, + idx = metropolis_step(this->_statevector, tk, this->rng, distrib, idx); // Burn-in. } // Sample for (size_t i = 0; i < num_samples; i++) { - idx = metropolis_step(this->_statevector, tk, gen, distrib, idx); + idx = metropolis_step(this->_statevector, tk, this->rng, distrib, + idx); if (cache.contains(idx)) { size_t cache_id = cache[idx]; @@ -562,9 +562,9 @@ class Measurements final auto &&probabilities = probs(); std::vector samples(num_samples * num_qubits, 0); - std::mt19937 generator(std::random_device{}()); std::uniform_real_distribution distribution(0.0, 1.0); std::unordered_map cache; + this->setRandomSeed(); const size_t N = probabilities.size(); std::vector bucket(N); @@ -611,7 +611,7 @@ class Measurements final // Pick samples for (size_t i = 0; i < num_samples; i++) { - PrecisionT pct = distribution(generator) * N; + PrecisionT pct = distribution(this->rng) * N; auto idx = static_cast(pct); if (pct - idx > bucket[idx]) { idx = bucket_partner[idx]; diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp b/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp index 66ac1fd87f..310cb4cb68 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp @@ -216,3 +216,38 @@ TEMPLATE_PRODUCT_TEST_CASE("StateVectorLQubit::applyOperations", LightningException, "must all be equal"); // invalid parameters } } + +TEMPLATE_PRODUCT_TEST_CASE("StateVectorLQubit::collapse", "[StateVectorLQubit]", + (StateVectorLQubitManaged, StateVectorLQubitRaw), + (float, double)) { + using StateVectorT = TestType; + using PrecisionT = typename StateVectorT::PrecisionT; + using ComplexT = typename StateVectorT::ComplexT; + using TestVectorT = TestVector; + + const std::size_t num_qubits = 3; + + SECTION("Collapse the state vector as after having measured one of the " + "qubits.") { + TestVectorT init_state = createPlusState(num_qubits); + + const ComplexT coef{0.5, PrecisionT{0.0}}; + const ComplexT zero{PrecisionT{0.0}, PrecisionT{0.0}}; + + std::vector>> expected_state = { + {{coef, coef, coef, coef, zero, zero, zero, zero}, + {coef, coef, zero, zero, coef, coef, zero, zero}, + {coef, zero, coef, zero, coef, zero, coef, zero}}, + {{zero, zero, zero, zero, coef, coef, coef, coef}, + {zero, zero, coef, coef, zero, zero, coef, coef}, + {zero, coef, zero, coef, zero, coef, zero, coef}}, + }; + + std::size_t wire = GENERATE(0, 1, 2); + std::size_t branch = GENERATE(0, 1); + StateVectorLQubitManaged sv(init_state); + sv.collapse(wire, branch); + + REQUIRE(sv.getDataVector() == approx(expected_state[branch][wire])); + } +} diff --git a/pennylane_lightning/lightning_kokkos/lightning_kokkos.py b/pennylane_lightning/lightning_kokkos/lightning_kokkos.py index d7821660bf..59a54d390c 100644 --- a/pennylane_lightning/lightning_kokkos/lightning_kokkos.py +++ b/pennylane_lightning/lightning_kokkos/lightning_kokkos.py @@ -64,7 +64,7 @@ from pennylane.ops.op_math import Adjoint from pennylane.wires import Wires - # pylint: disable=import-error, no-name-in-module, ungrouped-imports + # pylint: disable=ungrouped-imports from pennylane_lightning.lightning_kokkos_ops.algorithms import ( AdjointJacobianC64, AdjointJacobianC128, diff --git a/pennylane_lightning/lightning_qubit/lightning_qubit.py b/pennylane_lightning/lightning_qubit/lightning_qubit.py index ffdb5e33a4..ecdec93e88 100644 --- a/pennylane_lightning/lightning_qubit/lightning_qubit.py +++ b/pennylane_lightning/lightning_qubit/lightning_qubit.py @@ -22,6 +22,8 @@ from warnings import warn import numpy as np +from pennylane.measurements import MidMeasureMP +from pennylane.ops import Conditional from pennylane_lightning.core.lightning_base import ( LightningBase, @@ -66,6 +68,10 @@ from pennylane.ops.op_math import Adjoint from pennylane.wires import Wires + # pylint: disable=import-error, no-name-in-module, ungrouped-imports + from pennylane_lightning.core._serialize import QuantumScriptSerializer + from pennylane_lightning.core._version import __version__ + # pylint: disable=no-name-in-module, ungrouped-imports from pennylane_lightning.lightning_qubit_ops.algorithms import ( AdjointJacobianC64, @@ -76,10 +82,6 @@ create_ops_listC128, ) - # pylint: disable=import-error, no-name-in-module, ungrouped-imports - from pennylane_lightning.core._serialize import QuantumScriptSerializer - from pennylane_lightning.core._version import __version__ - def _state_dtype(dtype): if dtype not in [np.complex128, np.complex64]: # pragma: no cover raise ValueError(f"Data type is not supported for state-vector computation: {dtype}") @@ -169,6 +171,8 @@ def _state_dtype(dtype): "QFT", "ECR", "BlockEncode", + "MidMeasureMP", + "Conditional", } allowed_observables = { @@ -258,6 +262,15 @@ def __init__( # pylint: disable=too-many-arguments self._kernel_name = kernel_name self._num_burnin = num_burnin + # pylint: disable=missing-function-docstring + @classmethod + def capabilities(cls): + capabilities = super().capabilities().copy() + capabilities.update( + supports_mid_measure=True, + ) + return capabilities + @staticmethod def _asarray(arr, dtype=None): arr = np.asarray(arr) # arr is not copied @@ -377,10 +390,10 @@ def _apply_lightning_controlled(self, operation): """Apply an arbitrary controlled operation to the state tensor. Args: - operation (~pennylane.operation.Operation): operation to apply + operation (~pennylane.operation.Operation): controlled operation to apply Returns: - array[complex]: the output state tensor + None """ state = self.state_vector @@ -419,32 +432,56 @@ def _apply_lightning_controlled(self, operation): operation.base.matrix, control_wires, control_values, target_wires, False ) - def apply_lightning(self, operations): + def _apply_lightning_midmeasure(self, operation: MidMeasureMP, mid_measurements: dict): + """Execute a MidMeasureMP operation and return the sample in mid_measurements. + + Args: + operation (~pennylane.operation.Operation): mid-circuit measurement + + Returns: + None + """ + wires = self.wires.indices(operation.wires) + wire = list(wires)[0] + sample = qml.math.reshape(self.generate_samples(shots=1), (-1,))[wire] + if operation.postselect is not None and sample != operation.postselect: + mid_measurements[operation] = -1 + return + mid_measurements[operation] = sample + getattr(self.state_vector, "collapse")(wire, bool(sample)) + if operation.reset and bool(sample): + self.apply([qml.PauliX(operation.wires)], mid_measurements=mid_measurements) + + def apply_lightning(self, operations, mid_measurements=None): """Apply a list of operations to the state tensor. Args: operations (list[~pennylane.operation.Operation]): operations to apply Returns: - array[complex]: the output state tensor + None """ state = self.state_vector - # Skip over identity operations instead of performing # matrix multiplication with it. for operation in operations: + if isinstance(operation, qml.Identity): + continue if isinstance(operation, Adjoint): name = operation.base.name invert_param = True else: name = operation.name invert_param = False - if name == "Identity": - continue method = getattr(state, name, None) wires = self.wires.indices(operation.wires) - if method is not None: # apply specialized gate + if isinstance(operation, Conditional): + if operation.meas_val.concretize(mid_measurements): + self.apply_lightning([operation.then_op]) + elif isinstance(operation, MidMeasureMP): + self._apply_lightning_midmeasure(operation, mid_measurements) + elif method is not None: # apply specialized gate param = operation.parameters method(wires, invert_param, param) elif ( @@ -464,7 +501,7 @@ def apply_lightning(self, operations): method(operation.matrix, wires, False) # pylint: disable=unused-argument - def apply(self, operations, rotations=None, **kwargs): + def apply(self, operations, rotations=None, mid_measurements=None, **kwargs): """Applies operations to the state vector.""" # State preparation is currently done in Python if operations: # make sure operations[0] exists @@ -484,7 +521,7 @@ def apply(self, operations, rotations=None, **kwargs): f"Operations have already been applied on a {self.short_name} device." ) - self.apply_lightning(operations) + self.apply_lightning(operations, mid_measurements=mid_measurements) # pylint: disable=protected-access def expval(self, observable, shot_range=None, bin_size=None): @@ -606,7 +643,7 @@ def var(self, observable, shot_range=None, bin_size=None): return measurements.var(observable.name, observable_wires) - def generate_samples(self): + def generate_samples(self, shots=None): """Generate samples Returns: @@ -618,13 +655,12 @@ def generate_samples(self): if self.use_csingle else MeasurementsC128(self.state_vector) ) + shots = self.shots if shots is None else shots if self._mcmc: return measurements.generate_mcmc_samples( - len(self.wires), self._kernel_name, self._num_burnin, self.shots + len(self.wires), self._kernel_name, self._num_burnin, shots ).astype(int, copy=False) - return measurements.generate_samples(len(self.wires), self.shots).astype( - int, copy=False - ) + return measurements.generate_samples(len(self.wires), shots).astype(int, copy=False) def probability_lightning(self, wires): """Return the probability of each computational basis state. diff --git a/requirements-dev.txt b/requirements-dev.txt index a4ceae1c8c..d444068eca 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ pip~=22.0 -git+https://github.com/PennyLaneAI/pennylane.git@master +git+https://github.com/PennyLaneAI/pennylane.git@feature/mid-measure-cpp ninja flaky pybind11 diff --git a/tests/test_native_mcm.py b/tests/test_native_mcm.py new file mode 100644 index 0000000000..1ff5cbcb8d --- /dev/null +++ b/tests/test_native_mcm.py @@ -0,0 +1,423 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for default qubit preprocessing.""" +from functools import reduce +from typing import Sequence + +import numpy as np +import pennylane as qml +import pytest +from conftest import LightningDevice as ld +from conftest import device_name +from flaky import flaky + +if not ld._CPP_BINARY_AVAILABLE: + pytest.skip("No binary module found. Skipping.", allow_module_level=True) + +if device_name != "lightning.qubit": + pytest.skip(f"Native MCM not supported in device {device_name}.", allow_module_level=True) + + +def validate_counts(shots, results1, results2): + """Compares two counts. + + If the results are ``Sequence``s, loop over entries. + + Fails if a key of ``results1`` is not found in ``results2``. + Passes if counts are too low, chosen as ``100``. + Otherwise, fails if counts differ by more than ``20`` plus 20 percent. + """ + if isinstance(results1, Sequence): + assert isinstance(results2, Sequence) + assert len(results1) == len(results2) + for r1, r2 in zip(results1, results2): + validate_counts(shots, r1, r2) + return + for key1, val1 in results1.items(): + val2 = results2[key1] + if abs(val1 + val2) > 100: + assert np.allclose(val1, val2, rtol=20, atol=0.2) + + +def validate_samples(shots, results1, results2): + """Compares two samples. + + If the results are ``Sequence``s, loop over entries. + + Fails if the results do not have the same shape, within ``20`` entries plus 20 percent. + This is to handle cases when post-selection yields variable shapes. + Otherwise, fails if the sums of samples differ by more than ``20`` plus 20 percent. + """ + if isinstance(shots, Sequence): + assert isinstance(results1, Sequence) + assert isinstance(results2, Sequence) + assert len(results1) == len(results2) + for s, r1, r2 in zip(shots, results1, results2): + validate_samples(s, r1, r2) + else: + sh1, sh2 = results1.shape[0], results2.shape[0] + assert np.allclose(sh1, sh2, rtol=20, atol=0.2) + assert results1.ndim == results2.ndim + if results2.ndim > 1: + assert results1.shape[1] == results2.shape[1] + np.allclose(np.sum(results1), np.sum(results2), rtol=20, atol=0.2) + + +def validate_expval(shots, results1, results2): + """Compares two expval, probs or var. + + If the results are ``Sequence``s, validate the average of items. + + If ``shots is None``, validate using ``np.allclose``'s default parameters. + Otherwise, fails if the results do not match within ``0.01`` plus 20 percent. + """ + if isinstance(results1, Sequence): + assert isinstance(results2, Sequence) + assert len(results1) == len(results2) + results1 = reduce(lambda x, y: x + y, results1) / len(results1) + results2 = reduce(lambda x, y: x + y, results2) / len(results2) + validate_expval(shots, results1, results2) + return + if shots is None: + assert np.allclose(results1, results2) + return + assert np.allclose(results1, results2, atol=0.01, rtol=0.2) + + +def validate_measurements(func, shots, results1, results2): + """Calls the correct validation function based on measurement type.""" + if func is qml.counts: + validate_counts(shots, results1, results2) + return + + if func is qml.sample: + validate_samples(shots, results1, results2) + return + + validate_expval(shots, results1, results2) + + +def test_all_invalid_shots_circuit(): + dev = qml.device(device_name, wires=2) + dq = qml.device("default.qubit", wires=2) + + def circuit_op(): + m = qml.measure(0, postselect=1) + qml.cond(m, qml.PauliX)(1) + return ( + qml.expval(op=qml.PauliZ(1)), + qml.probs(op=qml.PauliY(0) @ qml.PauliZ(1)), + qml.var(op=qml.PauliZ(1)), + ) + + res1 = qml.QNode(circuit_op, dq)() + res2 = qml.QNode(circuit_op, dev)(shots=10) + for r1, r2 in zip(res1, res2): + if isinstance(r1, Sequence): + assert len(r1) == len(r2) + assert np.all(np.isnan(r1)) + assert np.all(np.isnan(r2)) + + def circuit_mcm(): + m = qml.measure(0, postselect=1) + qml.cond(m, qml.PauliX)(1) + return qml.expval(op=m), qml.probs(op=m), qml.var(op=m) + + res1 = qml.QNode(circuit_mcm, dq)() + res2 = qml.QNode(circuit_mcm, dev)(shots=10) + for r1, r2 in zip(res1, res2): + if isinstance(r1, Sequence): + assert len(r1) == len(r2) + assert np.all(np.isnan(r1)) + assert np.all(np.isnan(r2)) + + +def test_unsupported_measurement(): + dev = qml.device(device_name, wires=2, shots=1000) + params = np.pi / 4 * np.ones(2) + + @qml.qnode(dev) + def func(x, y): + qml.RX(x, wires=0) + m0 = qml.measure(0) + qml.cond(m0, qml.RY)(y, wires=1) + return qml.classical_shadow(wires=0) + + with pytest.raises( + TypeError, + match=f"Native mid-circuit measurement mode does not support {type(qml.classical_shadow(wires=0)).__name__}", + ): + func(*params) + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [5000, [5000, 5001]]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) +def test_simple_mcm(shots, postselect, reset, measure_f): + """Tests that LightningQubit handles a circuit with a single mid-circuit measurement and a + conditional gate. A single measurement of the mid-circuit measurement value is performed at + the end.""" + + dev = qml.device(device_name, wires=1, shots=shots) + dq = qml.device("default.qubit", shots=shots) + params = np.pi / 4 * np.ones(2) + + def func(x, y): + qml.RX(x, wires=0) + m0 = qml.measure(0, reset=reset, postselect=postselect) + qml.cond(m0, qml.RY)(y, wires=0) + return measure_f(op=qml.PauliZ(0)) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(*params) + results2 = func2(*params) + + if postselect is None or measure_f in (qml.expval, qml.probs, qml.var): + validate_measurements(measure_f, shots, results1, results2) + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [1000, [1000, 1001]]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) +def test_single_mcm_single_measure_mcm(shots, postselect, reset, measure_f): + """Tests that LightningQubit handles a circuit with a single mid-circuit measurement and a + conditional gate. A single measurement of the mid-circuit measurement value is performed at + the end.""" + + dev = qml.device(device_name, wires=2, shots=shots) + dq = qml.device("default.qubit", shots=shots) + params = np.pi / 4 * np.ones(2) + + def func(x, y): + qml.RX(x, wires=0) + m0 = qml.measure(0, reset=reset, postselect=postselect) + qml.cond(m0, qml.RY)(y, wires=1) + return measure_f(op=m0) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(*params) + results2 = func2(*params) + + if postselect is None or measure_f in (qml.expval, qml.probs, qml.var): + validate_measurements(measure_f, shots, results1, results2) + + +# pylint: disable=unused-argument +def obs_tape(x, y, z, reset=False, postselect=None): + qml.RX(x, 0) + qml.RZ(np.pi / 4, 0) + m0 = qml.measure(0, reset=reset) + qml.cond(m0 == 0, qml.RX)(np.pi / 4, 0) + qml.cond(m0 == 0, qml.RZ)(np.pi / 4, 0) + qml.cond(m0 == 1, qml.RX)(-np.pi / 4, 0) + qml.cond(m0 == 1, qml.RZ)(-np.pi / 4, 0) + qml.RX(y, 1) + qml.RZ(np.pi / 4, 1) + m1 = qml.measure(1, postselect=postselect) + qml.cond(m1 == 0, qml.RX)(np.pi / 4, 1) + qml.cond(m1 == 0, qml.RZ)(np.pi / 4, 1) + qml.cond(m1 == 1, qml.RX)(-np.pi / 4, 1) + qml.cond(m1 == 1, qml.RZ)(-np.pi / 4, 1) + return m0, m1 + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [5000, [5000, 5001]]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) +@pytest.mark.parametrize("obs", [qml.PauliZ(0), qml.PauliY(1), qml.PauliZ(0) @ qml.PauliY(1)]) +def test_single_mcm_single_measure_obs(shots, postselect, reset, measure_f, obs): + """Tests that LightningQubit handles a circuit with a single mid-circuit measurement and a + conditional gate. A single measurement of a common observable is performed at the end.""" + + dev = qml.device(device_name, wires=2, shots=shots) + dq = qml.device("default.qubit", shots=shots) + params = [np.pi / 7, np.pi / 6, -np.pi / 5] + + def func(x, y, z): + obs_tape(x, y, z, reset=reset, postselect=postselect) + return measure_f(op=obs) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(*params) + results2 = func2(*params) + + if postselect is None or measure_f in (qml.expval, qml.probs, qml.var): + validate_measurements(measure_f, shots, results1, results2) + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [3000, [3000, 3001]]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.probs, qml.sample]) +@pytest.mark.parametrize("wires", [[0], [0, 1]]) +def test_single_mcm_single_measure_wires(shots, postselect, reset, measure_f, wires): + """Tests that LightningQubit handles a circuit with a single mid-circuit measurement and a + conditional gate. A single measurement of one or several wires is performed at the end.""" + + dev = qml.device(device_name, wires=2, shots=shots) + dq = qml.device("default.qubit", shots=shots) + params = np.pi / 4 * np.ones(2) + + def func(x, y): + qml.RX(x, wires=0) + m0 = qml.measure(0, reset=reset, postselect=postselect) + qml.cond(m0, qml.RY)(y, wires=1) + return measure_f(wires=wires) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(*params) + results2 = func2(*params) + + if postselect is None or measure_f in (qml.expval, qml.probs, qml.var): + validate_measurements(measure_f, shots, results1, results2) + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [5000]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) +def test_single_mcm_multiple_measurements(shots, postselect, reset, measure_f): + """Tests that LightningQubit handles a circuit with a single mid-circuit measurement with reset + and a conditional gate. Multiple measurements of the mid-circuit measurement value are + performed.""" + + dev = qml.device(device_name, wires=2, shots=shots) + dq = qml.device("default.qubit", shots=shots) + params = [np.pi / 7, np.pi / 6, -np.pi / 5] + obs = qml.PauliY(1) + + def func(x, y, z): + mcms = obs_tape(x, y, z, reset=reset, postselect=postselect) + return measure_f(op=obs), measure_f(op=mcms[0]) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(*params) + results2 = func2(*params) + + if postselect is None or measure_f in (qml.expval, qml.probs, qml.var): + for r1, r2 in zip(results1, results2): + validate_measurements(measure_f, shots, r1, r2) + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [5000, [5000, 5001]]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.sample, qml.var]) +def test_composite_mcm_measure_composite_mcm(shots, postselect, reset, measure_f): + """Tests that LightningQubit handles a circuit with a composite mid-circuit measurement and a + conditional gate. A single measurement of a composite mid-circuit measurement is performed + at the end.""" + + dev = qml.device(device_name, wires=2, shots=shots) + dq = qml.device("default.qubit", shots=shots) + param = np.pi / 3 + + def func(x): + qml.RX(x, 0) + m0 = qml.measure(0) + qml.RX(0.5 * x, 1) + m1 = qml.measure(1, reset=reset, postselect=postselect) + qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0) + m2 = qml.measure(0) + return measure_f(op=(m0 - 2 * m1) * m2 + 7) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(param) + results2 = func2(param) + + if postselect is None or measure_f in (qml.expval, qml.probs, qml.var): + validate_measurements(measure_f, shots, results1, results2) + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [5000, [5000, 5001]]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var]) +def test_composite_mcm_single_measure_obs(shots, postselect, reset, measure_f): + """Tests that LightningQubit handles a circuit with a composite mid-circuit measurement and a + conditional gate. A single measurement of a common observable is performed at the end.""" + + dev = qml.device(device_name, wires=2, shots=shots) + dq = qml.device("default.qubit", shots=shots) + params = [np.pi / 7, np.pi / 6, -np.pi / 5] + obs = qml.PauliZ(0) @ qml.PauliY(1) + + def func(x, y, z): + mcms = obs_tape(x, y, z, reset=reset, postselect=postselect) + qml.cond(mcms[0] != mcms[1], qml.RY)(z, wires=0) + qml.cond(mcms[0] == mcms[1], qml.RY)(z, wires=1) + return measure_f(op=obs) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(*params) + results2 = func2(*params) + + if postselect is None or measure_f in (qml.expval, qml.probs, qml.var): + validate_measurements(measure_f, shots, results1, results2) + + +@flaky(max_runs=5) +@pytest.mark.parametrize("shots", [5000, [5000, 5001]]) +@pytest.mark.parametrize("postselect", [None, 0, 1]) +@pytest.mark.parametrize("reset", [False, True]) +@pytest.mark.parametrize("measure_f", [qml.counts, qml.probs, qml.sample]) +def test_composite_mcm_measure_value_list(shots, postselect, reset, measure_f): + """Tests that LightningQubit handles a circuit with a composite mid-circuit measurement and a + conditional gate. A single measurement of a composite mid-circuit measurement is performed + at the end.""" + + dev = qml.device(device_name, wires=2, shots=shots) + dq = qml.device("default.qubit", shots=shots) + param = np.pi / 3 + + def func(x): + qml.RX(x, 0) + m0 = qml.measure(0) + qml.RX(0.5 * x, 1) + m1 = qml.measure(1, reset=reset, postselect=postselect) + qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0) + m2 = qml.measure(0) + return measure_f(op=[m0, m1, m2]) + + func1 = qml.QNode(func, dev) + func2 = qml.defer_measurements(qml.QNode(func, dq)) + + results1 = func1(param) + results2 = func2(param) + + validate_measurements(measure_f, shots, results1, results2)