Skip to content

Commit

Permalink
Fix state prep operation decomposition with LightningQubit (#661)
Browse files Browse the repository at this point in the history
* Fixed LQ adjoint decomp

* Auto update version

* Trigger CI

* Apply suggestions from code review

Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com>

* Auto update version

* trigger ci

* trigger ci

* Added fix for MCM with adjoint

* Fixed LQ tests

---------

Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com>
Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com>
Co-authored-by: Vincent Michaud-Rioux <vincent.michaud-rioux@xanadu.ai>
  • Loading branch information
5 people authored Mar 25, 2024
1 parent 5f95a2f commit 7e77c5d
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
* Fix random `coverage xml` CI issues.
[(#635)](https://github.com/PennyLaneAI/pennylane-lightning/pull/635)

* `lightning.qubit` correctly decomposed state preparation operations with adjoint differentiation.
[(#661)](https://github.com/PennyLaneAI/pennylane-lightning/pull/661)

### Contributors

This release contains contributions from (in alphabetical order):
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.36.0-dev17"
__version__ = "0.36.0-dev18"
6 changes: 4 additions & 2 deletions pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def _supports_adjoint(circuit):

try:
prog((circuit,))
except (qml.operation.DecompositionUndefinedError, qml.DeviceError):
except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError):
return False
return True

Expand All @@ -282,7 +282,9 @@ def _add_adjoint_transforms(program: TransformProgram) -> None:

name = "adjoint + lightning.qubit"
program.add_transform(no_sampling, name=name)
program.add_transform(decompose, stopping_condition=adjoint_ops, name=name)
program.add_transform(
decompose, stopping_condition=adjoint_ops, name=name, skip_initial_state_prep=False
)
program.add_transform(validate_observables, accepted_observables, name=name)
program.add_transform(
validate_measurements, analytic_measurements=adjoint_measurements, name=name
Expand Down
55 changes: 55 additions & 0 deletions tests/new_api/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_add_adjoint_transforms(self):
decompose,
stopping_condition=adjoint_ops,
name=name,
skip_initial_state_prep=False,
)
expected_program.add_transform(validate_observables, accepted_observables, name=name)
expected_program.add_transform(
Expand Down Expand Up @@ -248,6 +249,7 @@ def test_preprocess(self, adjoint):
decompose,
stopping_condition=adjoint_ops,
name=name,
skip_initial_state_prep=False,
)
expected_program.add_transform(validate_observables, accepted_observables, name=name)
expected_program.add_transform(
Expand Down Expand Up @@ -554,6 +556,59 @@ def test_derivatives_no_trainable_params(self, dev, execute_and_derivatives, bat
assert len(jac) == 1
assert qml.math.shape(jac[0]) == (0,)

@pytest.mark.parametrize("execute_and_derivatives", [True, False])
@pytest.mark.parametrize(
"state_prep, params, wires",
[
(qml.BasisState, [1, 1], [0, 1]),
(qml.StatePrep, [0.0, 0.0, 0.0, 1.0], [0, 1]),
(qml.StatePrep, qml.numpy.array([0.0, 1.0]), [1]),
],
)
@pytest.mark.parametrize(
"trainable_params",
[(0, 1, 2), (1, 2)],
)
def test_state_prep_ops(
self, dev, state_prep, params, wires, execute_and_derivatives, batch_obs, trainable_params
):
"""Test that a circuit containing state prep operations is differentiated correctly."""
qs = QuantumScript(
[state_prep(params, wires), qml.RX(1.23, 0), qml.CNOT([0, 1]), qml.RX(4.56, 1)],
[qml.expval(qml.PauliZ(1))],
)

config = ExecutionConfig(gradient_method="adjoint", device_options={"batch_obs": batch_obs})
program, new_config = dev.preprocess(config)
tapes, fn = program([qs])
tapes[0].trainable_params = trainable_params
if execute_and_derivatives:
res, jac = dev.execute_and_compute_derivatives(tapes, new_config)
res = fn(res)
else:
res, jac = (
fn(dev.execute(tapes, new_config)),
dev.compute_derivatives(tapes, new_config),
)

dev_ref = DefaultQubit(max_workers=1)
config = ExecutionConfig(gradient_method="adjoint")
program, new_config = dev_ref.preprocess(config)
tapes, fn = program([qs])
tapes[0].trainable_params = trainable_params
if execute_and_derivatives:
expected, expected_jac = dev_ref.execute_and_compute_derivatives(tapes, new_config)
expected = fn(expected)
else:
expected, expected_jac = (
fn(dev_ref.execute(tapes, new_config)),
dev_ref.compute_derivatives(tapes, new_config),
)

tol = 1e-5 if dev.c_dtype == np.complex64 else 1e-7
assert np.allclose(res, expected, atol=tol, rtol=0)
assert np.allclose(jac, expected_jac, atol=tol, rtol=0)

def test_state_jacobian_not_supported(self, dev, batch_obs):
"""Test that an error is raised if derivatives are requested for state measurement"""
qs = QuantumScript([qml.RX(1.23, 0)], [qml.state()], trainable_params=[0])
Expand Down

0 comments on commit 7e77c5d

Please sign in to comment.