diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 462fe9eb8b..bb55e38510 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -50,6 +50,9 @@ * Fix random `coverage xml` CI issues. [(#635)](https://github.com/PennyLaneAI/pennylane-lightning/pull/635) +* `lightning.qubit` correctly decomposed state preparation operations with adjoint differentiation. + [(#661)](https://github.com/PennyLaneAI/pennylane-lightning/pull/661) + ### Contributors This release contains contributions from (in alphabetical order): diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py index 4089152989..ec9ead0f50 100644 --- a/pennylane_lightning/core/_version.py +++ b/pennylane_lightning/core/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.36.0-dev17" +__version__ = "0.36.0-dev18" diff --git a/pennylane_lightning/lightning_qubit/lightning_qubit.py b/pennylane_lightning/lightning_qubit/lightning_qubit.py index 7499f0190e..53df30e45d 100644 --- a/pennylane_lightning/lightning_qubit/lightning_qubit.py +++ b/pennylane_lightning/lightning_qubit/lightning_qubit.py @@ -263,7 +263,7 @@ def _supports_adjoint(circuit): try: prog((circuit,)) - except (qml.operation.DecompositionUndefinedError, qml.DeviceError): + except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError): return False return True @@ -282,7 +282,9 @@ def _add_adjoint_transforms(program: TransformProgram) -> None: name = "adjoint + lightning.qubit" program.add_transform(no_sampling, name=name) - program.add_transform(decompose, stopping_condition=adjoint_ops, name=name) + program.add_transform( + decompose, stopping_condition=adjoint_ops, name=name, skip_initial_state_prep=False + ) program.add_transform(validate_observables, accepted_observables, name=name) program.add_transform( validate_measurements, analytic_measurements=adjoint_measurements, name=name diff --git a/tests/new_api/test_device.py b/tests/new_api/test_device.py index 3a0461d55d..accb8d17c7 100644 --- a/tests/new_api/test_device.py +++ b/tests/new_api/test_device.py @@ -86,6 +86,7 @@ def test_add_adjoint_transforms(self): decompose, stopping_condition=adjoint_ops, name=name, + skip_initial_state_prep=False, ) expected_program.add_transform(validate_observables, accepted_observables, name=name) expected_program.add_transform( @@ -248,6 +249,7 @@ def test_preprocess(self, adjoint): decompose, stopping_condition=adjoint_ops, name=name, + skip_initial_state_prep=False, ) expected_program.add_transform(validate_observables, accepted_observables, name=name) expected_program.add_transform( @@ -554,6 +556,59 @@ def test_derivatives_no_trainable_params(self, dev, execute_and_derivatives, bat assert len(jac) == 1 assert qml.math.shape(jac[0]) == (0,) + @pytest.mark.parametrize("execute_and_derivatives", [True, False]) + @pytest.mark.parametrize( + "state_prep, params, wires", + [ + (qml.BasisState, [1, 1], [0, 1]), + (qml.StatePrep, [0.0, 0.0, 0.0, 1.0], [0, 1]), + (qml.StatePrep, qml.numpy.array([0.0, 1.0]), [1]), + ], + ) + @pytest.mark.parametrize( + "trainable_params", + [(0, 1, 2), (1, 2)], + ) + def test_state_prep_ops( + self, dev, state_prep, params, wires, execute_and_derivatives, batch_obs, trainable_params + ): + """Test that a circuit containing state prep operations is differentiated correctly.""" + qs = QuantumScript( + [state_prep(params, wires), qml.RX(1.23, 0), qml.CNOT([0, 1]), qml.RX(4.56, 1)], + [qml.expval(qml.PauliZ(1))], + ) + + config = ExecutionConfig(gradient_method="adjoint", device_options={"batch_obs": batch_obs}) + program, new_config = dev.preprocess(config) + tapes, fn = program([qs]) + tapes[0].trainable_params = trainable_params + if execute_and_derivatives: + res, jac = dev.execute_and_compute_derivatives(tapes, new_config) + res = fn(res) + else: + res, jac = ( + fn(dev.execute(tapes, new_config)), + dev.compute_derivatives(tapes, new_config), + ) + + dev_ref = DefaultQubit(max_workers=1) + config = ExecutionConfig(gradient_method="adjoint") + program, new_config = dev_ref.preprocess(config) + tapes, fn = program([qs]) + tapes[0].trainable_params = trainable_params + if execute_and_derivatives: + expected, expected_jac = dev_ref.execute_and_compute_derivatives(tapes, new_config) + expected = fn(expected) + else: + expected, expected_jac = ( + fn(dev_ref.execute(tapes, new_config)), + dev_ref.compute_derivatives(tapes, new_config), + ) + + tol = 1e-5 if dev.c_dtype == np.complex64 else 1e-7 + assert np.allclose(res, expected, atol=tol, rtol=0) + assert np.allclose(jac, expected_jac, atol=tol, rtol=0) + def test_state_jacobian_not_supported(self, dev, batch_obs): """Test that an error is raised if derivatives are requested for state measurement""" qs = QuantumScript([qml.RX(1.23, 0)], [qml.state()], trainable_params=[0])