From 2c39a2bf1f913376ff6b6fbdf08a78f777f91df0 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 22 Mar 2024 14:33:11 -0400 Subject: [PATCH 1/9] Fixed LQ adjoint decomp --- .github/CHANGELOG.md | 3 ++ .../lightning_qubit/lightning_qubit.py | 4 +- tests/new_api/test_device.py | 53 +++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 7fd030a55d..b8d961b07e 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -47,6 +47,9 @@ * Fix random `coverage xml` CI issues. [(#635)](https://github.com/PennyLaneAI/pennylane-lightning/pull/635) +* `lightning.qubit` correctly decomposed state preparation operations with adjoint differentiation. + [(#)]() + ### Contributors This release contains contributions from (in alphabetical order): diff --git a/pennylane_lightning/lightning_qubit/lightning_qubit.py b/pennylane_lightning/lightning_qubit/lightning_qubit.py index 5f42947376..f41fa8f036 100644 --- a/pennylane_lightning/lightning_qubit/lightning_qubit.py +++ b/pennylane_lightning/lightning_qubit/lightning_qubit.py @@ -268,7 +268,9 @@ def _add_adjoint_transforms(program: TransformProgram) -> None: name = "adjoint + lightning.qubit" program.add_transform(no_sampling, name=name) - program.add_transform(decompose, stopping_condition=adjoint_ops, name=name) + program.add_transform( + decompose, stopping_condition=adjoint_ops, name=name, skip_initial_state_prep=False + ) program.add_transform(validate_observables, accepted_observables, name=name) program.add_transform( validate_measurements, analytic_measurements=adjoint_measurements, name=name diff --git a/tests/new_api/test_device.py b/tests/new_api/test_device.py index 6407ed60b5..6d11be5199 100644 --- a/tests/new_api/test_device.py +++ b/tests/new_api/test_device.py @@ -552,6 +552,59 @@ def test_derivatives_no_trainable_params(self, dev, execute_and_derivatives, bat assert len(jac) == 1 assert qml.math.shape(jac[0]) == (0,) + @pytest.mark.parametrize("execute_and_derivatives", [True, False]) + @pytest.mark.parametrize( + "state_prep, params, wires", + [ + (qml.BasisState, [1, 1], [0, 1]), + (qml.StatePrep, [0.0, 0.0, 0.0, 1.0], [0, 1]), + (qml.StatePrep, qml.numpy.array([0.0, 1.0]), [1]), + ], + ) + @pytest.mark.parametrize( + "trainable_params", + [(0, 1, 2), (1, 2)], + ) + def test_state_prep_ops( + self, dev, state_prep, params, wires, execute_and_derivatives, batch_obs, trainable_params + ): + """Test that a circuit containing state prep operations is differentiated correctly.""" + qs = QuantumScript( + [state_prep(params, wires), qml.RX(1.23, 0), qml.CNOT([0, 1]), qml.RX(4.56, 1)], + [qml.expval(qml.PauliZ(1))], + ) + + config = ExecutionConfig(gradient_method="adjoint", device_options={"batch_obs": batch_obs}) + program, new_config = dev.preprocess(config) + tapes, fn = program([qs]) + tapes[0].trainable_params = trainable_params + if execute_and_derivatives: + res, jac = dev.execute_and_compute_derivatives(tapes, new_config) + res = fn(res) + else: + res, jac = ( + fn(dev.execute(tapes, new_config)), + dev.compute_derivatives(tapes, new_config), + ) + + dev_ref = DefaultQubit(max_workers=1) + config = ExecutionConfig(gradient_method="adjoint") + program, new_config = dev_ref.preprocess(config) + tapes, fn = program([qs]) + tapes[0].trainable_params = trainable_params + if execute_and_derivatives: + expected, expected_jac = dev_ref.execute_and_compute_derivatives(tapes, new_config) + expected = fn(expected) + else: + expected, expected_jac = ( + fn(dev_ref.execute(tapes, new_config)), + dev_ref.compute_derivatives(tapes, new_config), + ) + + tol = 1e-5 if dev.c_dtype == np.complex64 else 1e-7 + assert np.allclose(res, expected, atol=tol, rtol=0) + assert np.allclose(jac, expected_jac, atol=tol, rtol=0) + def test_state_jacobian_not_supported(self, dev, batch_obs): """Test that an error is raised if derivatives are requested for state measurement""" qs = QuantumScript([qml.RX(1.23, 0)], [qml.state()], trainable_params=[0]) From 00f61dd955900dc669593d162e833239c8fe2698 Mon Sep 17 00:00:00 2001 From: Dev version update bot Date: Fri, 22 Mar 2024 18:35:37 +0000 Subject: [PATCH 2/9] Auto update version --- pennylane_lightning/core/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py index 778353eb5b..4089152989 100644 --- a/pennylane_lightning/core/_version.py +++ b/pennylane_lightning/core/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.36.0-dev16" +__version__ = "0.36.0-dev17" From 5a560869aa69f97a5924ecf045694aed1dca981b Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 22 Mar 2024 14:39:01 -0400 Subject: [PATCH 3/9] Trigger CI From 6e05bd3d589352dbf287b6c0316b0dd66d09aa4d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 22 Mar 2024 14:45:39 -0400 Subject: [PATCH 4/9] Apply suggestions from code review Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com> --- .github/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index b8d961b07e..1bd153e52c 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -48,7 +48,7 @@ [(#635)](https://github.com/PennyLaneAI/pennylane-lightning/pull/635) * `lightning.qubit` correctly decomposed state preparation operations with adjoint differentiation. - [(#)]() + [(#661)](https://github.com/PennyLaneAI/pennylane-lightning/pull/661) ### Contributors From c2681f3259419998496a9a45dbd2e364aba6f5c6 Mon Sep 17 00:00:00 2001 From: Dev version update bot Date: Fri, 22 Mar 2024 18:45:52 +0000 Subject: [PATCH 5/9] Auto update version --- pennylane_lightning/core/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py index 4089152989..ec9ead0f50 100644 --- a/pennylane_lightning/core/_version.py +++ b/pennylane_lightning/core/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.36.0-dev17" +__version__ = "0.36.0-dev18" From 55f83fde774d06d1a2032dc08f47c0c5536c43e5 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Fri, 22 Mar 2024 14:50:11 -0400 Subject: [PATCH 6/9] trigger ci From 6e540bf103536762a375f246319b76834d6547e7 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Fri, 22 Mar 2024 14:50:21 -0400 Subject: [PATCH 7/9] trigger ci From 9dd53949e57515403cd5c01d6e99536b0533b70c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 22 Mar 2024 19:30:43 -0400 Subject: [PATCH 8/9] Added fix for MCM with adjoint --- pennylane_lightning/lightning_qubit/lightning_qubit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane_lightning/lightning_qubit/lightning_qubit.py b/pennylane_lightning/lightning_qubit/lightning_qubit.py index b3e6ada608..53df30e45d 100644 --- a/pennylane_lightning/lightning_qubit/lightning_qubit.py +++ b/pennylane_lightning/lightning_qubit/lightning_qubit.py @@ -263,7 +263,7 @@ def _supports_adjoint(circuit): try: prog((circuit,)) - except (qml.operation.DecompositionUndefinedError, qml.DeviceError): + except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError): return False return True From c83caef4ce1d6fdccb2b8f069b40f157f191cb93 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 22 Mar 2024 20:05:47 -0400 Subject: [PATCH 9/9] Fixed LQ tests --- tests/new_api/test_device.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/new_api/test_device.py b/tests/new_api/test_device.py index e3658c9900..accb8d17c7 100644 --- a/tests/new_api/test_device.py +++ b/tests/new_api/test_device.py @@ -86,6 +86,7 @@ def test_add_adjoint_transforms(self): decompose, stopping_condition=adjoint_ops, name=name, + skip_initial_state_prep=False, ) expected_program.add_transform(validate_observables, accepted_observables, name=name) expected_program.add_transform( @@ -248,6 +249,7 @@ def test_preprocess(self, adjoint): decompose, stopping_condition=adjoint_ops, name=name, + skip_initial_state_prep=False, ) expected_program.add_transform(validate_observables, accepted_observables, name=name) expected_program.add_transform(