Skip to content

Commit

Permalink
Make BasisRotation jit compatible (#6779)
Browse files Browse the repository at this point in the history
**Context:**

PR #6019 only fixes `BasisRotation` when using backprop on
`default.qubit`. It is not jit compatible on any other device. This is
because `unitary_matrix` was being considered a hyperparameter, not a
piece of data. So we could not detect that the matrix was a tracer and
we were in jitting mode, and we could not convert the matrix back into
numpy data.

**Description of the Change:**

Make `unitary_matrix` a piece of data instead of a hyperparameter. This
allows us to detect when it is being jitted.

As a by-product, I also made it valid pytree.

By making `unitary_matrix` a piece of data, we were able to get rid of
the custom comparison method in `qml.equal`.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-51603] Fixes #6004
  • Loading branch information
albi3ro authored Jan 8, 2025
1 parent 22b172b commit 0f95698
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 45 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-0.40.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@

* `qml.BasisRotation` template is now JIT compatible.
[(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)
[(#6779)](https://github.com/PennyLaneAI/pennylane/pull/6779)

* The Jaxpr primitives for `for_loop`, `while_loop` and `cond` now store slices instead of
numbers of args.
Expand Down
33 changes: 0 additions & 33 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,39 +754,6 @@ def _equal_counts(op1: CountsMP, op2: CountsMP, **kwargs):
return _equal_measurements(op1, op2, **kwargs) and op1.all_outcomes == op2.all_outcomes


@_equal_dispatch.register
# pylint: disable=unused-argument
def _equal_basis_rotation(
op1: qml.BasisRotation,
op2: qml.BasisRotation,
check_interface=True,
check_trainability=True,
rtol=1e-5,
atol=1e-9,
):
if not qml.math.allclose(
op1.hyperparameters["unitary_matrix"],
op2.hyperparameters["unitary_matrix"],
atol=atol,
rtol=rtol,
):
return (
"The hyperparameter unitary_matrix is not equal for op1 and op2.\n"
f"Got {op1.hyperparameters['unitary_matrix']}\n and {op2.hyperparameters['unitary_matrix']}."
)
if op1.wires != op2.wires:
return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}."
if check_interface:
interface1 = qml.math.get_interface(op1.hyperparameters["unitary_matrix"])
interface2 = qml.math.get_interface(op2.hyperparameters["unitary_matrix"])
if interface1 != interface2:
return (
"The hyperparameter unitary_matrix has different interfaces for op1 and op2."
f" Got {interface1} and {interface2}."
)
return True


@_equal_dispatch.register
def _equal_hilbert_schmidt(
op1: qml.HilbertSchmidt,
Expand Down
14 changes: 7 additions & 7 deletions pennylane/templates/subroutines/basis_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def _primitive_bind_call(cls, wires, unitary_matrix, check=False, id=None):

return cls._primitive.bind(*wires, unitary_matrix, check=check, id=id)

@classmethod
def _unflatten(cls, data, metadata):
return cls(wires=metadata[0], unitary_matrix=data[0])

def __init__(self, wires, unitary_matrix, check=False, id=None):
M, N = qml.math.shape(unitary_matrix)

Expand All @@ -124,19 +128,15 @@ def __init__(self, wires, unitary_matrix, check=False, id=None):
if len(wires) < 2:
raise ValueError(f"This template requires at least two wires, got {len(wires)}")

self._hyperparameters = {
"unitary_matrix": unitary_matrix,
}

super().__init__(wires=wires, id=id)
super().__init__(unitary_matrix, wires=wires, id=id)

@property
def num_params(self):
return 0
return 1

@staticmethod
def compute_decomposition(
wires, unitary_matrix, check=False
unitary_matrix, wires, check=False
): # pylint: disable=arguments-differ
r"""Representation of the operator as a product of other operators.
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2598,7 +2598,7 @@ def test_different_tolerances_comparison(self, op, other_op):
assert_equal(op, other_op, atol=1e-5)
assert qml.equal(op, other_op, rtol=0, atol=1e-9) is False

with pytest.raises(AssertionError, match="The hyperparameter unitary_matrix is not equal"):
with pytest.raises(AssertionError, match="op1 and op2 have different data"):
assert_equal(op, other_op, rtol=0, atol=1e-9)

@pytest.mark.parametrize("op, other_op", [(op1, op2)])
Expand Down Expand Up @@ -2629,7 +2629,7 @@ def test_non_equal_interfaces(self, op):
assert_equal(op, other_op, check_interface=False)
assert qml.equal(op, other_op) is False

with pytest.raises(AssertionError, match=r"has different interfaces for op1 and op2"):
with pytest.raises(AssertionError, match=r"have different interfaces"):
assert_equal(op, other_op)


Expand Down
6 changes: 3 additions & 3 deletions tests/templates/test_subroutines/test_basis_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import pennylane as qml


@pytest.mark.xfail # to be fixed by [sc-51603]
def test_standard_validity():
"""Run standard tests of operation validity."""
weights = np.array(
Expand Down Expand Up @@ -402,8 +401,9 @@ def test_autograd(self, tol):

assert np.allclose(grads, np.zeros_like(unitary_matrix, dtype=complex), atol=tol, rtol=0)

@pytest.mark.parametrize("device_name", ("default.qubit", "reference.qubit"))
@pytest.mark.jax
def test_jax_jit(self, tol):
def test_jax_jit(self, device_name, tol):
"""Test the jax interface."""

import jax
Expand All @@ -417,7 +417,7 @@ def test_jax_jit(self, tol):
]
)

dev = qml.device("default.qubit", wires=3)
dev = qml.device(device_name, wires=3)

circuit = jax.jit(qml.QNode(circuit_template, dev), static_argnames="check")
circuit2 = qml.QNode(circuit_template, dev)
Expand Down

0 comments on commit 0f95698

Please sign in to comment.