[BUG] `BasisRotation` is not JAX compatible #6004
Comments
@josh146 what priority would you assign to this? Perhaps a P1?
@trbromley there are a few other related bugs (I think?). @KetpuntoG is working on simplifying our state-preparation suite based on this epic: https://app.shortcut.com/xanaduai/epic/66499?group_by=none&vc_group_by=day&ct_workflow=all&cf_workflow=500000005. He might run into some dragons related to these bugs while completing this work. Maybe it's best to wait for him to get started, see whether these bugs are blockers, and then assign a PX?
**Context:** `qml.BasisRotation` was not jit-compatible.

**Description of the Change:** We changed all the NumPy arrays to `jax.numpy` arrays and made sure the tests pass.

**Benefits:** The basis rotation is jittable and thus `qjit`-compatible.

**Possible Drawbacks:** `jax.numpy` may be slower than standard NumPy.

**Related GitHub Issues:** #6004

---------

Co-authored-by: obliviateandsurrender <utkarshazad98@gmail.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: soranjh <40344468+soranjh@users.noreply.github.com>
Co-authored-by: Will <wmaxwell90@gmail.com>
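A minimal sketch of why the switch to `jax.numpy` matters, assuming only JAX and NumPy are installed. Under `jax.jit` a function receives tracers instead of concrete arrays; `np.asarray` needs a concrete value and raises `jax.errors.TracerArrayConversionError`, while the same computation written with `jax.numpy` traces without issue. The helpers `gram_numpy` and `gram_jax` are hypothetical and are not PennyLane code.

```python
import jax
import jax.numpy as jnp
import numpy as np


def gram_numpy(U):
    # Hypothetical helper: NumPy needs a concrete array, so this fails on a tracer
    U = np.asarray(U).copy()
    return U @ U.conj().T


def gram_jax(U):
    # Same computation with jax.numpy, which operates on tracers symbolically
    U = jnp.asarray(U)
    return U @ U.conj().T


U = jnp.array([[0.0, 1.0], [1.0, 0.0]])

# Outside of jit both versions work, because U is a concrete array
eager_numpy = gram_numpy(U)

# Under jit, only the jax.numpy version can be traced
jitted_jax = jax.jit(gram_jax)(U)

try:
    jax.jit(gram_numpy)(U)
    numpy_failed_under_jit = False
except jax.errors.TracerArrayConversionError:
    numpy_failed_under_jit = True

print(numpy_failed_under_jit)  # True
```

The eager NumPy path still works, which is why a bug like this only surfaces once the operator is used inside a jitted function.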
**Context:** PR #6019 only fixes `BasisRotation` when using backprop on `default.qubit`; it is not jit-compatible on any other device. This is because `unitary_matrix` was treated as a hyperparameter rather than as data. As a result, we could not detect that the matrix was a tracer (and hence that we were in jitting mode), and we could not convert the matrix back into NumPy data.

**Description of the Change:** Make `unitary_matrix` a piece of data instead of a hyperparameter. This allows us to detect when it is being jitted. As a by-product, I also made the operator a valid pytree. Because `unitary_matrix` is now data, we were able to get rid of the custom comparison method in `qml.equal`.

**Related GitHub Issues:** [sc-51603] Fixes #6004
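To illustrate the data-versus-hyperparameter distinction, here is a sketch using a hypothetical `ToyBasisRotation` class registered as a JAX pytree; it is not PennyLane's `BasisRotation`. The matrix is a pytree leaf (data), so JAX passes a tracer for it under `jax.jit`, and that tracer can be detected. The wires are static `aux_data` (a hyperparameter), which must be hashable. The hypothetical `generic_equal` shows how, once everything numerical lives in the leaves, a comparison can rely on the tree structure plus leaf-wise `allclose` instead of a per-operator special case. This is an assumption-laden sketch, not the `qml.equal` implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyBasisRotation:
    """Hypothetical operator: the matrix is data, the wires are a hyperparameter."""

    def __init__(self, unitary_matrix, wires):
        self.data = (unitary_matrix,)  # leaves that JAX traces
        self.hyperparameters = {"wires": tuple(wires)}  # static, hashable metadata

    def tree_flatten(self):
        # Return (children, aux_data): children are traced, aux_data is static
        return self.data, self.hyperparameters["wires"]

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, wires=aux_data)


def is_tracer(x):
    # A tracer cannot be converted into a concrete NumPy array
    try:
        np.asarray(x)
        return False
    except jax.errors.TracerArrayConversionError:
        return True


def generic_equal(op1, op2, atol=1e-8):
    # Compare class and static metadata via the treedef, then data leaves numerically
    leaves1, treedef1 = jax.tree_util.tree_flatten(op1)
    leaves2, treedef2 = jax.tree_util.tree_flatten(op2)
    if treedef1 != treedef2:
        return False
    return all(bool(jnp.allclose(a, b, atol=atol)) for a, b in zip(leaves1, leaves2))


trace_time_checks = []


@jax.jit
def trace_of_matrix(op):
    (U,) = op.data
    trace_time_checks.append(is_tracer(U))  # Python side effect: runs only while tracing
    return jnp.trace(U)


op = ToyBasisRotation(jnp.eye(2), wires=[0, 1])
value = trace_of_matrix(op)
```

Because the matrix lives in `data`, the jitted function sees a tracer for it and can react accordingly. Had the matrix been stored only in `aux_data`, JAX would treat it as static metadata, and a JAX array there would not even be hashable.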
### Expected behavior

`qml.BasisRotation` is not supported with `jax.jit` or `qml.qjit`.

### Actual behavior

### Additional information

This is occurring because the decomposition for `qml.BasisRotation` calls the `qml.qchem.givens_decomposition` function, which is not JAX compatible:

- `U` is being converted to a NumPy array and copied
- … `qml.math` functions) are used