diff --git a/tests/test_measures.py b/tests/test_measures.py index a78eadcc83..16108d6d06 100644 --- a/tests/test_measures.py +++ b/tests/test_measures.py @@ -118,13 +118,15 @@ def circuit(): @pytest.mark.parametrize( "cases", [ - [[1, 0], [0.9165164490394898, 0.08348355096051052, 0.0, 0.0]], + [[0, 1], [1, 0]], + [[1, 0], [0, 1]], ], ) - def test_fail_probs_tape_unordered_wires(self, cases, tol, dev): - """Test probs with a circuit on wires=[0]""" + def test_fail_probs_tape_unordered_wires(self, cases, tol): + """Test probs with a circuit on wires=[0] fails for out-of-order wires passed to probs.""" x, y, z = [0.5, 0.3, -0.7] + dev = qml.device("lightning.qubit", wires=cases[1]) @qml.qnode(dev) def circuit(): @@ -137,7 +139,29 @@ def circuit(): RuntimeError, match="Lightning does not currently support out-of-order indices for probabilities", ): - assert np.allclose(circuit(), cases[1], atol=tol, rtol=0) + _ = circuit() + + @pytest.mark.parametrize( + "cases", + [ + [[1, 0], [1, 0], [0.9165164490394898, 0.08348355096051052, 0.0, 0.0]], + [[2, 0], [2, 0, 1], [0.9165164490394898, 0.08348355096051052, 0.0, 0.0]], + ], + ) + def test_probs_matching_device_wire_order(self, cases, tol): + """Test probs with a circuit on wires=[0] passes if wires are sorted wrt device wires.""" + + x, y, z = [0.5, 0.3, -0.7] + dev = qml.device("lightning.qubit", wires=cases[1]) + + @qml.qnode(dev) + def circuit(): + qml.RX(0.4, wires=[0]) + qml.Rot(x, y, z, wires=[0]) + qml.RY(-0.2, wires=[0]) + return qml.probs(wires=cases[0]) + + assert np.allclose(circuit(), cases[2], atol=tol, rtol=0) @pytest.mark.parametrize( "cases",