[BUG] BasisRotation is not JAX compatible #6004

Closed
josh146 opened this issue Jul 17, 2024 · 2 comments · Fixed by #6019 or #6779
Labels
bug 🐛 Something isn't working

Comments

josh146 (Member) commented Jul 17, 2024

Expected behavior

qml.BasisRotation should work with jax.jit and qml.qjit; currently it does not.

Actual behavior

```python
import jax
import pennylane as qml
from scipy.stats import unitary_group

dev = qml.device("lightning.qubit", wires=4)

@jax.jit
@qml.qnode(dev)
def f(U):
    for i in range(4):
        qml.Hadamard(i)
    qml.BasisRotation(unitary_matrix=U, wires=[0, 1, 2, 3], check=False)
    return qml.expval(qml.PauliZ(0))

U = unitary_group.rvs(4)  # random 4x4 unitary matrix
f(U)  # raises the ValueError in the traceback below
```
```text
/usr/local/lib/python3.10/dist-packages/pennylane/operation.py in decomposition(self)
   1318             list[Operator]: decomposition of the operator
   1319         """
-> 1320         return self.compute_decomposition(
   1321             *self.parameters, wires=self.wires, **self.hyperparameters
   1322         )

/usr/local/lib/python3.10/dist-packages/pennylane/templates/subroutines/basis_rotation.py in compute_decomposition(wires, unitary_matrix, check)
    169
    170         op_list = []
--> 171         phase_list, givens_list = givens_decomposition(unitary_matrix)
    172
    173         for idx, phase in enumerate(phase_list):

/usr/local/lib/python3.10/dist-packages/pennylane/qchem/givens_decomposition.py in givens_decomposition(unitary)
    149     """
    150
--> 151     unitary, (M, N) = qml.math.toarray(unitary).copy(), unitary.shape
    152     if M != N:
    153         raise ValueError(f"The unitary matrix should be of shape NxN, got {unitary.shape}")

/usr/local/lib/python3.10/dist-packages/autoray/autoray.py in do(fn, like, *args, **kwargs)
     79     backend = _choose_backend(fn, args, kwargs, like=like)
     80     func = get_lib_fn(backend, fn)
---> 81     return func(*args, **kwargs)
     82
     83

/usr/local/lib/python3.10/dist-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
    783         return np.array(getattr(x, "val", x))
    784     except TracerArrayConversionError as e:
--> 785         raise ValueError(
    786             "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
    787         ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
```
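The root cause in the traceback can be reproduced without PennyLane. Inside `jax.jit`, array arguments are abstract tracers with no concrete values, so converting one to a NumPy array (which `_to_numpy_jax` attempts via `np.array`, per the traceback) raises `jax.errors.TracerArrayConversionError`. A minimal sketch; the function name `to_numpy_then_sum` is purely illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_numpy_then_sum(x):
    # np.asarray asks x for concrete values; a tracer cannot provide them
    return np.asarray(x).sum()


# Eager call: x is a concrete jax.Array, so the conversion succeeds
eager_result = to_numpy_then_sum(jnp.eye(2))

# Jitted call: x is a tracer, so the conversion fails
try:
    jax.jit(to_numpy_then_sum)(jnp.eye(2))
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True
```

PennyLane catches this error and re-raises it as the `ValueError` shown above.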

Additional information

This occurs because the decomposition of qml.BasisRotation calls the qml.qchem.givens_decomposition function, which is not JAX compatible:

  • The unitary matrix U is converted to a NumPy array and copied.
  • The copied matrix is then updated in place.
  • NumPy functions are used instead of qml.math functions.
  • Exceptions are raised based on array values, which cannot be inspected while tracing.
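On the last point: under `jax.jit`, a Python `if` on an array *value* cannot be evaluated, while checks on array *shapes* are fine because shapes are static. One possible pattern, sketched here as an assumption and not necessarily what the eventual fix does, is to keep the shape check as an ordinary `ValueError` and move the value check into `jax.experimental.checkify`. The function name `conjugate_transpose_checked` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def conjugate_transpose_checked(unitary):
    # Shape checks are allowed under jit: shapes are static Python ints
    m, n = unitary.shape
    if m != n:
        raise ValueError(f"The unitary matrix should be of shape NxN, got {unitary.shape}")
    # A Python `if not np.allclose(...)` would fail on a tracer;
    # checkify records the failed check instead of branching in Python
    is_unitary = jnp.allclose(unitary @ unitary.conj().T, jnp.eye(n), atol=1e-6)
    checkify.check(is_unitary, "The input matrix must be unitary")
    return unitary.conj().T


# checkify-transformed functions return (error, output); jit wraps the pair
checked = jax.jit(checkify.checkify(conjugate_transpose_checked))
```

Usage: `err, u_dagger = checked(U)` followed by `err.throw()` raises only if the unitarity check failed.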

josh146 added the bug 🐛 label Jul 17, 2024
josh146 changed the title from [BUG] to [BUG] BasisRotation is not JAX compatible Jul 17, 2024
trbromley (Contributor)

@josh146 what priority would you assign to this? Perhaps a P1?

isaacdevlugt (Contributor)

@trbromley there are a few other related bugs (I think?)

#6006
#6007
#6008

@KetpuntoG is working on simplifying our state-preparation suite based on this epic: https://app.shortcut.com/xanaduai/epic/66499?group_by=none&vc_group_by=day&ct_workflow=all&cf_workflow=500000005. He may run into some dragons related to these bugs while completing this work. Maybe it's best to wait for him to get started, see whether these bugs are blockers, and then assign a priority (PX)?

PabloAMC linked a pull request Jul 22, 2024 that will close this issue
willjmax added a commit that referenced this issue Nov 4, 2024
**Context:** `qml.BasisRotation` was not jit-compatible.

**Description of the Change:** We replaced the NumPy array operations with
their `jax.numpy` equivalents and ensured the tests pass.

**Benefits:** The basis rotation is jittable and thus `qjit` compatible.

**Possible Drawbacks:** `jax.numpy` may be slower than plain NumPy.

**Related GitHub Issues:**
#6004

---------

Co-authored-by: obliviateandsurrender <utkarshazad98@gmail.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: soranjh <40344468+soranjh@users.noreply.github.com>
Co-authored-by: Will <wmaxwell90@gmail.com>
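For illustration, here is a minimal sketch of one jittable elimination step of a Givens-rotation decomposition, written with `jax.numpy` in the spirit of the change above. The helpers `givens_matrix` and `zero_entry` are hypothetical and are not the code merged in the PR; the sketch assumes rows `row` and `row + 1` are rotated to zero the entry `unitary[row + 1, col]`. It shows three JAX-friendly choices: the rotation is built from traced values, the zero-norm special case uses `jnp.where` instead of a Python `if`, and rows are replaced with `.at[...].set(...)` instead of in-place assignment.

```python
import jax
import jax.numpy as jnp


def givens_matrix(a, b):
    """Hypothetical helper: 2x2 unitary G with G @ [a, b] = [r, 0], r = sqrt(|a|^2 + |b|^2)."""
    r = jnp.sqrt(jnp.abs(a) ** 2 + jnp.abs(b) ** 2)
    # jnp.where evaluates both branches, so avoid dividing by zero
    safe_r = jnp.where(r == 0, 1.0, r)
    g = jnp.array([[jnp.conj(a), jnp.conj(b)], [-b, a]]) / safe_r
    # Value-dependent special case without a Python `if`: identity when a = b = 0
    return jnp.where(r == 0, jnp.eye(2, dtype=g.dtype), g)


@jax.jit
def zero_entry(unitary, row, col):
    """Hypothetical helper: zero unitary[row + 1, col] by rotating rows row and row + 1."""
    a = unitary[row, col]
    b = unitary[row + 1, col]
    g = givens_matrix(a, b)
    top = g[0, 0] * unitary[row] + g[0, 1] * unitary[row + 1]
    bottom = g[1, 0] * unitary[row] + g[1, 1] * unitary[row + 1]
    # Functional updates return a new array instead of mutating the input
    return unitary.at[row].set(top).at[row + 1].set(bottom)
```

Because `row` and `col` are traced rather than static, the same compiled function can be reused for every elimination step.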
mudit2812 pushed a commit that referenced this issue Nov 11, 2024
albi3ro added a commit that referenced this issue Jan 8, 2025
**Context:**

PR #6019 only fixes `BasisRotation` when using backprop on
`default.qubit`; it is not jit-compatible on any other device. This is
because `unitary_matrix` was treated as a hyperparameter rather than as
data. As a result, we could not detect that the matrix was a tracer (i.e.,
that we were in jitting mode), and we could not convert the matrix back
into NumPy data.

**Description of the Change:**

Make `unitary_matrix` a piece of data instead of a hyperparameter. This
allows us to detect when it is being jitted.

As a by-product, I also made it a valid pytree.

By making `unitary_matrix` a piece of data, we were able to get rid of
the custom comparison method in `qml.equal`.
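The data/hyperparameter distinction matters because of how JAX handles pytrees: only the leaves are traced, while the auxiliary (static) part is never turned into a tracer and must be hashable. The minimal sketch below uses a hypothetical `ToyOp` class, not PennyLane's `Operator`, to show that a matrix stored in `data` becomes a leaf, while the same matrix stored in `hyperparameters` does not.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten


@register_pytree_node_class
class ToyOp:
    """Hypothetical operator-like container (not PennyLane's Operator)."""

    def __init__(self, data, wires, hyperparameters=None):
        self.data = tuple(data)  # arrays that JAX should trace
        self.wires = tuple(wires)  # static metadata
        self.hyperparameters = dict(hyperparameters or {})  # static metadata

    def tree_flatten(self):
        # Leaves: only `data`. Aux data: everything else (must be hashable under jit)
        aux = (self.wires, tuple(sorted(self.hyperparameters.items())))
        return self.data, aux

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        wires, hyper_items = aux
        return cls(leaves, wires, dict(hyper_items))


@jax.jit
def trace_of_matrix(op):
    # op.data[0] is a tracer here, so operations on it are traced and compiled
    return jnp.trace(op.data[0])


matrix = jnp.eye(2)
as_data = ToyOp([matrix], wires=[0, 1])
as_hyper = ToyOp([], wires=[0, 1], hyperparameters={"unitary_matrix": matrix})
```

With the matrix in `hyperparameters`, `tree_flatten` reports no leaves, so a tracing-aware check that only inspects the leaves would not see it; in addition, JAX arrays are not hashable, so such an object is not well suited to being static pytree metadata.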

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-51603] Fixes #6004