Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update LightningQubit to adhere to MCM qnode arguments #736

Merged
merged 31 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
66840b0
Updated LQ to work with runtime kwargs
mudit2812 May 21, 2024
4f446d3
Auto update version from '0.37.0-dev14' to '0.37.0-dev15'
ringo-but-quantum May 21, 2024
be8702d
Update changelog
mudit2812 May 21, 2024
0689e61
[skip ci] Skip CI
mudit2812 May 21, 2024
e21874e
Added tests
mudit2812 May 22, 2024
479f599
Auto update version from '0.37.0-dev15' to '0.37.0-dev16'
ringo-but-quantum May 22, 2024
7ea8eee
Merge branch 'master' into mcm-kwargs
mudit2812 May 22, 2024
39f08ca
Trigger CI
mudit2812 May 22, 2024
e434082
Trigger CI
mudit2812 May 22, 2024
dc81cd3
Fixed test
mudit2812 May 22, 2024
9314278
Fix preprocess test
mudit2812 May 22, 2024
4e29899
Auto update version from '0.37.0-dev16' to '0.37.0-dev17'
ringo-but-quantum May 22, 2024
18c30ab
Merge branch 'master' into mcm-kwargs
mudit2812 May 22, 2024
36faf75
Updated tests
mudit2812 May 22, 2024
83cbd39
Changed postselect_shots to postselect_mode
mudit2812 May 23, 2024
ac019ac
Auto update version from '0.37.0-dev17' to '0.37.0-dev18'
ringo-but-quantum May 23, 2024
d4cfb4d
Trigger CI
mudit2812 May 23, 2024
914ae8b
Merge branch 'master' into mcm-kwargs
mudit2812 May 24, 2024
f7de3e2
Merge branch 'master' into mcm-kwargs
mudit2812 May 31, 2024
be64657
Auto update version from '0.37.0-dev23' to '0.37.0-dev24'
ringo-but-quantum May 31, 2024
3b06152
Updated LQ and LK for mcm config support
mudit2812 Jun 3, 2024
3c44bd4
Auto update version from '0.37.0-dev24' to '0.37.0-dev25'
ringo-but-quantum Jun 3, 2024
97397cd
Update default postselect_mode to None
mudit2812 Jun 3, 2024
b09734f
Auto update version from '0.37.0-dev25' to '0.37.0-dev26'
ringo-but-quantum Jun 3, 2024
feaabaa
Merge branch 'master' into mcm-kwargs
mudit2812 Jun 3, 2024
aea378c
Linting
mudit2812 Jun 3, 2024
9986c7a
Fixed failing test
mudit2812 Jun 4, 2024
6f92dbb
Merge branch 'master' into mcm-kwargs
mudit2812 Jun 4, 2024
61bf035
Auto update version from '0.37.0-dev26' to '0.37.0-dev27'
ringo-but-quantum Jun 4, 2024
c854b6a
Update requirements-dev.txt
mudit2812 Jun 4, 2024
0459ede
Trigger CI
mudit2812 Jun 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@
* Changed the name of `lightning.tensor` to `default.tensor` with the `quimb` backend.
[(#719)](https://github.com/PennyLaneAI/pennylane-lightning/pull/719)

* `lightning.qubit` and `lightning.kokkos` adhere to user specified mid-circuit measurement configuration options.
[(#736)](https://github.com/PennyLaneAI/pennylane-lightning/pull/736)

* Patch the C++ `Measurements.probs(wires)` method in Lightning-Qubit and Lighnting-Kokkos to `Measurements.probs()` when called with all wires.
This will trigger a more optimized implementation for calculating the probabilities of the entire system.
[(#744)](https://github.com/PennyLaneAI/pennylane-lightning/pull/744)

* Remove the daily schedule from the "Compat Check w/PL - release/release" GitHub action.
[(#746)](https://github.com/PennyLaneAI/pennylane-lightning/pull/746)


### Documentation

### Bug fixes
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.37.0-dev26"
__version__ = "0.37.0-dev27"
25 changes: 17 additions & 8 deletions pennylane_lightning/lightning_kokkos/lightning_kokkos.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,7 @@
@classmethod
def capabilities(cls):
capabilities = super().capabilities().copy()
capabilities.update(
supports_mid_measure=True,
)
capabilities.update(supports_mid_measure=True)
return capabilities

@staticmethod
Expand Down Expand Up @@ -383,22 +381,29 @@
num = self._get_basis_state_index(state, wires)
self._create_basis_state(num)

def _apply_lightning_midmeasure(self, operation: MidMeasureMP, mid_measurements: dict):
def _apply_lightning_midmeasure(
self, operation: MidMeasureMP, mid_measurements: dict, postselect_mode: str
):
"""Execute a MidMeasureMP operation and return the sample in mid_measurements.

Args:
operation (~pennylane.operation.Operation): mid-circuit measurement

Returns:
None
"""
wires = self.wires.indices(operation.wires)
wire = list(wires)[0]
sample = qml.math.reshape(self.generate_samples(shots=1), (-1,))[wire]
if postselect_mode == "fill-shots" and operation.postselect is not None:
sample = operation.postselect

Check warning on line 398 in pennylane_lightning/lightning_kokkos/lightning_kokkos.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_kokkos/lightning_kokkos.py#L398

Added line #L398 was not covered by tests
else:
sample = qml.math.reshape(self.generate_samples(shots=1), (-1,))[wire]
mid_measurements[operation] = sample
getattr(self.state_vector, "collapse")(wire, bool(sample))
if operation.reset and bool(sample):
self.apply([qml.PauliX(operation.wires)], mid_measurements=mid_measurements)

def apply_lightning(self, operations, mid_measurements=None):
def apply_lightning(self, operations, mid_measurements=None, postselect_mode=None):
"""Apply a list of operations to the state tensor.

Args:
Expand Down Expand Up @@ -429,7 +434,7 @@
if ops.meas_val.concretize(mid_measurements):
self.apply_lightning([ops.then_op])
elif isinstance(ops, MidMeasureMP):
self._apply_lightning_midmeasure(ops, mid_measurements)
self._apply_lightning_midmeasure(ops, mid_measurements, postselect_mode)
elif isinstance(ops, qml.ops.op_math.Controlled) and isinstance(
ops.base, qml.GlobalPhase
):
Expand Down Expand Up @@ -471,14 +476,18 @@
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]

postselect_mode = kwargs.get("postselect_mode", None)

for operation in operations:
if isinstance(operation, (StatePrep, BasisState)):
raise DeviceError(
f"Operation {operation.name} cannot be used after other "
+ f"Operations have already been applied on a {self.short_name} device."
)

self.apply_lightning(operations, mid_measurements=mid_measurements)
self.apply_lightning(
operations, mid_measurements=mid_measurements, postselect_mode=postselect_mode
)

# pylint: disable=protected-access
def expval(self, observable, shot_range=None, bin_size=None):
Expand Down
52 changes: 40 additions & 12 deletions pennylane_lightning/lightning_qubit/_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,33 +250,45 @@ def _apply_lightning_controlled(self, operation):
False,
)

def _apply_lightning_midmeasure(self, operation: MidMeasureMP, mid_measurements: dict):
def _apply_lightning_midmeasure(
self, operation: MidMeasureMP, mid_measurements: dict, postselect_mode: str
):
"""Execute a MidMeasureMP operation and return the sample in mid_measurements.

Args:
operation (~pennylane.operation.Operation): mid-circuit measurement
mid_measurements (None, dict): Dictionary of mid-circuit measurements
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots.

Returns:
None
"""
wires = self.wires.indices(operation.wires)
wire = list(wires)[0]
circuit = QuantumScript([], [qml.sample(wires=operation.wires)], shots=1)
sample = LightningMeasurements(self).measure_final_state(circuit)
sample = np.squeeze(sample)
if operation.postselect is not None and sample != operation.postselect:
mid_measurements[operation] = -1
return
if postselect_mode == "fill-shots" and operation.postselect is not None:
sample = operation.postselect
else:
sample = LightningMeasurements(self).measure_final_state(circuit)
sample = np.squeeze(sample)
mid_measurements[operation] = sample
getattr(self.state_vector, "collapse")(wire, bool(sample))
if operation.reset and bool(sample):
self.apply_operations([qml.PauliX(operation.wires)], mid_measurements=mid_measurements)

def _apply_lightning(self, operations, mid_measurements: dict = None):
def _apply_lightning(
self, operations, mid_measurements: dict = None, postselect_mode: str = None
):
"""Apply a list of operations to the state tensor.

Args:
operations (list[~pennylane.operation.Operation]): operations to apply
mid_measurements (None, dict): Dictionary of mid-circuit measurements
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.

Returns:
None
Expand All @@ -301,7 +313,9 @@ def _apply_lightning(self, operations, mid_measurements: dict = None):
if operation.meas_val.concretize(mid_measurements):
self._apply_lightning([operation.then_op])
elif isinstance(operation, MidMeasureMP):
self._apply_lightning_midmeasure(operation, mid_measurements)
self._apply_lightning_midmeasure(
operation, mid_measurements, postselect_mode=postselect_mode
)
elif method is not None: # apply specialized gate
param = operation.parameters
method(wires, invert_param, param)
Expand All @@ -317,7 +331,9 @@ def _apply_lightning(self, operations, mid_measurements: dict = None):
# To support older versions of PL
method(operation.matrix, wires, False)

def apply_operations(self, operations, mid_measurements: dict = None):
def apply_operations(
self, operations, mid_measurements: dict = None, postselect_mode: str = None
):
"""Applies operations to the state vector."""
# State preparation is currently done in Python
if operations: # make sure operations[0] exists
Expand All @@ -328,9 +344,16 @@ def apply_operations(self, operations, mid_measurements: dict = None):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]

self._apply_lightning(operations, mid_measurements=mid_measurements)
self._apply_lightning(
operations, mid_measurements=mid_measurements, postselect_mode=postselect_mode
)

def get_final_state(self, circuit: QuantumScript, mid_measurements: dict = None):
def get_final_state(
self,
circuit: QuantumScript,
mid_measurements: dict = None,
postselect_mode: str = None,
):
"""
Get the final state that results from executing the given quantum script.

Expand All @@ -339,11 +362,16 @@ def get_final_state(self, circuit: QuantumScript, mid_measurements: dict = None)
Args:
circuit (QuantumScript): The single circuit to simulate
mid_measurements (None, dict): Dictionary of mid-circuit measurements
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.

Returns:
LightningStateVector: Lightning final state class.

"""
self.apply_operations(circuit.operations, mid_measurements=mid_measurements)
self.apply_operations(
circuit.operations, mid_measurements=mid_measurements, postselect_mode=postselect_mode
)

return self
27 changes: 23 additions & 4 deletions pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@
PostprocessingFn = Callable[[ResultBatch], Result_or_ResultBatch]


def simulate(circuit: QuantumScript, state: LightningStateVector, mcmc: dict = None) -> Result:
def simulate(
circuit: QuantumScript,
state: LightningStateVector,
mcmc: dict = None,
postselect_mode: str = None,
) -> Result:
"""Simulate a single quantum script.

Args:
Expand All @@ -67,6 +72,9 @@ def simulate(circuit: QuantumScript, state: LightningStateVector, mcmc: dict = N
mcmc (dict): Dictionary containing the Markov Chain Monte Carlo
parameters: mcmc, kernel_name, num_burnin. Descriptions of
these fields are found in :class:`~.LightningQubit`.
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.

Returns:
Tuple[TensorLike]: The results of the simulation
Expand All @@ -88,7 +96,9 @@ def simulate(circuit: QuantumScript, state: LightningStateVector, mcmc: dict = N
for _ in range(circuit.shots.total_shots):
state.reset_state()
mid_measurements = {}
final_state = state.get_final_state(aux_circ, mid_measurements=mid_measurements)
final_state = state.get_final_state(
aux_circ, mid_measurements=mid_measurements, postselect_mode=postselect_mode
)
results.append(
LightningMeasurements(final_state, **mcmc).measure_final_state(
aux_circ, mid_measurements=mid_measurements
Expand Down Expand Up @@ -571,7 +581,9 @@ def preprocess(self, execution_config: ExecutionConfig = DefaultExecutionConfig)
program.add_transform(validate_measurements, name=self.name)
program.add_transform(validate_observables, accepted_observables, name=self.name)
program.add_transform(validate_device_wires, self.wires, name=self.name)
program.add_transform(mid_circuit_measurements, device=self)
program.add_transform(
mid_circuit_measurements, device=self, mcm_config=exec_config.mcm_config
)
program.add_transform(
decompose,
stopping_condition=stopping_condition,
Expand Down Expand Up @@ -609,7 +621,14 @@ def execute(
for circuit in circuits:
if self._wire_map is not None:
[circuit], _ = qml.map_wires(circuit, self._wire_map)
results.append(simulate(circuit, self._statevector, mcmc=mcmc))
results.append(
simulate(
circuit,
self._statevector,
mcmc=mcmc,
postselect_mode=execution_config.mcm_config.postselect_mode,
)
)

return tuple(results)

Expand Down
6 changes: 4 additions & 2 deletions tests/new_api/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pennylane as qml
import pytest
from conftest import PHI, THETA, VARPHI, LightningDevice
from pennylane.devices import DefaultExecutionConfig, DefaultQubit, ExecutionConfig
from pennylane.devices import DefaultExecutionConfig, DefaultQubit, ExecutionConfig, MCMConfig
from pennylane.devices.default_qubit import adjoint_ops
from pennylane.tape import QuantumScript

Expand Down Expand Up @@ -259,7 +259,9 @@ def test_preprocess(self, adjoint):
expected_program.add_transform(validate_measurements, name=device.name)
expected_program.add_transform(validate_observables, accepted_observables, name=device.name)
expected_program.add_transform(validate_device_wires, device.wires, name=device.name)
expected_program.add_transform(mid_circuit_measurements, device=device)
expected_program.add_transform(
mid_circuit_measurements, device=device, mcm_config=MCMConfig()
)
expected_program.add_transform(
decompose,
stopping_condition=stopping_condition,
Expand Down
55 changes: 55 additions & 0 deletions tests/test_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,61 @@ def func(x, y):
func(*params)


@pytest.mark.parametrize("mcm_method", ["deferred", "one-shot"])
def test_qnode_mcm_method(mcm_method, mocker):
"""Test that user specified qnode arg for mid-circuit measurements transform are used correctly"""
spy = (
mocker.spy(qml.dynamic_one_shot, "_transform")
if mcm_method == "one-shot"
else mocker.spy(qml.defer_measurements, "_transform")
)
other_spy = (
mocker.spy(qml.defer_measurements, "_transform")
if mcm_method == "one-shot"
else mocker.spy(qml.dynamic_one_shot, "_transform")
)

shots = 10
device = qml.device(device_name, wires=3, shots=shots)

@qml.qnode(device, mcm_method=mcm_method)
def f(x):
qml.RX(x, 0)
_ = qml.measure(0)
qml.CNOT([0, 1])
return qml.sample(wires=[0, 1])

_ = f(np.pi / 8)

spy.assert_called_once()
other_spy.assert_not_called()


@pytest.mark.parametrize("postselect_mode", ["hw-like", "fill-shots"])
def test_qnode_postselect_mode(postselect_mode):
"""Test that user specified qnode arg for discarding invalid shots is used correctly"""
shots = 100
device = qml.device(device_name, wires=3, shots=shots)
postselect = 1

@qml.qnode(device, postselect_mode=postselect_mode)
def f(x):
qml.RX(x, 0)
_ = qml.measure(0, postselect=postselect)
qml.CNOT([0, 1])
return qml.sample(wires=[1])

# Using small-ish rotation angle ensures the number of valid shots will be less than the
# original number of shots. This helps avoid stochastic failures for the assertion below
res = f(np.pi / 2)

if postselect_mode == "hw-like":
assert len(res) < shots
else:
assert len(res) == shots
assert np.allclose(res, postselect)


@flaky(max_runs=5)
@pytest.mark.parametrize("shots", [5000, [5000, 5001]])
@pytest.mark.parametrize("postselect", [None, 0, 1])
Expand Down
Loading