diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 769215f38f..8974518ab9 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -21,9 +21,23 @@ from catalyst import pipeline, qjit from catalyst.passes import cancel_inverses, merge_rotations +default_device = qml.device("default.qubit", wires=1) + # pylint: disable=missing-function-docstring +def _assert_against_reference(circuit, theta, backend, optimization): + + customized_device = qml.device(backend, wires=1) + + reference_workflow = qml.QNode(circuit, default_device) + qjitted_workflow = qjit(qml.QNode(circuit, customized_device)) + optimized_workflow = qjit(optimization(qml.QNode(circuit, customized_device))) + + assert np.allclose(reference_workflow(theta), qjitted_workflow(theta)) + assert np.allclose(reference_workflow(theta), optimized_workflow(theta)) + + # # cancel_inverses # @@ -33,74 +47,19 @@ @pytest.mark.parametrize("theta", [42.42]) def test_cancel_inverses_functionality(theta, backend): - @qjit - def workflow(): - @qml.qnode(qml.device(backend, wires=1)) - def f(x): - qml.RX(x, wires=0) - qml.Hadamard(wires=0) - qml.Hadamard(wires=0) - return qml.probs() - - @cancel_inverses - @qml.qnode(qml.device(backend, wires=1)) - def g(x): - qml.RX(x, wires=0) - qml.Hadamard(wires=0) - qml.Hadamard(wires=0) - return qml.probs() - - return f(theta), g(theta) - - @qml.qnode(qml.device("default.qubit", wires=1)) - def reference(x): + def circuit(x): qml.RX(x, wires=0) qml.Hadamard(wires=0) qml.Hadamard(wires=0) return qml.probs() - assert np.allclose(workflow()[0], workflow()[1]) - assert np.allclose(workflow()[1], reference(theta)) + _assert_against_reference(circuit, theta, backend, cancel_inverses) @pytest.mark.parametrize("theta", [42.42]) def test_merge_rotation_functionality(theta, backend): - @qjit - def workflow(): - @qml.qnode(qml.device(backend, wires=1)) - def f(x): - qml.RX(x, wires=0) - qml.RX(x, wires=0) - qml.RZ(x, wires=0) - qml.adjoint(qml.RZ)(x, wires=0) - qml.Rot(x, x, x, wires=0) - qml.Rot(x, x, x, wires=0) - qml.PhaseShift(x, wires=0) - qml.PhaseShift(x, wires=0) - qml.Hadamard(wires=0) - qml.Hadamard(wires=0) - return qml.probs() - - @merge_rotations - @qml.qnode(qml.device(backend, wires=1)) - def g(x): - qml.RX(x, wires=0) - qml.RX(x, wires=0) - qml.RZ(x, wires=0) - qml.adjoint(qml.RZ)(x, wires=0) - qml.Rot(x, x, x, wires=0) - qml.Rot(x, x, x, wires=0) - qml.PhaseShift(x, wires=0) - qml.PhaseShift(x, wires=0) - qml.Hadamard(wires=0) - qml.Hadamard(wires=0) - return qml.probs() - - return f(theta), g(theta) - - @qml.qnode(qml.device("default.qubit", wires=1)) - def reference(x): + def circuit(x): qml.RX(x, wires=0) qml.RX(x, wires=0) qml.RZ(x, wires=0) @@ -113,8 +72,7 @@ def reference(x): qml.Hadamard(wires=0) return qml.probs() - assert np.allclose(workflow()[0], workflow()[1]) - assert np.allclose(workflow()[1], reference(theta)) + _assert_against_reference(circuit, theta, backend, merge_rotations) @pytest.mark.parametrize("theta", [42.42])