diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index d07237e510..a179ef51d0 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -1780,5 +1780,46 @@ def fn(x): assert np.allclose(res_pattern_partial, expected) +class TestDecompositionGradient: + """Test usage Gradient functions over decomposed gates""" + + def test_compute_decomposition(self): + """Test usage over computed decomposition.""" + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev) + def circuit(x): + U = jnp.array([[1, 0], [0, x]]) + decomp = qml.QubitUnitary.compute_decomposition(U, wires=0) + for op in decomp: + qml.apply(op) + return qml.probs() + + def f(x): + probs = circuit(x) + return probs[0] + probs[1] + + assert np.isnan(grad(qjit(f), argnums=0)(0.0)) == True + + def test_unitary_to_rot(self): + """Test usage with unitary to rot transform.""" + + dev = qml.device("lightning.qubit", wires=1) + + @qml.transforms.unitary_to_rot + @qml.qnode(dev) + def circuit(x): + U = jnp.array([[1, 0], [0, x]]) + qml.QubitUnitary(U, wires=0) + return qml.probs() + + def f(x): + probs = circuit(x) + return probs[0] + probs[1] + + assert np.isnan(grad(qjit(f), argnums=0)(0.0)) == True + + if __name__ == "__main__": pytest.main(["-x", __file__])