-
Notifications
You must be signed in to change notification settings - Fork 585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements to Jitting: broadcasted measurements and counts on all wires #6108
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6108 +/- ##
==========================================
- Coverage 99.67% 99.66% -0.01%
==========================================
Files 432 443 +11
Lines 41839 42240 +401
==========================================
+ Hits 41702 42098 +396
- Misses 137 142 +5 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 🚀
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎉
…ires (#6108) **Context:** When using jax jit and non-backprop, we need to know the exact shape of the result value in order to use a `pure_callback` for the device execution. The framework we use to determine this shape currently does not work for measurements on all available wires when the device does not specify a wire order. It also explicitly errors out with `qml.counts`, even when `all_outcomes=True`. But when `all_outcomes=True`, we know the exact resulting shape, so we can integrate it with jax-jit anyway. **Description of the Change:** This PR makes a breaking change to `MeasurementProcess.shape`. Now it's call signature is `(self, shots: Optional[int]=None, num_device_wires:int =0)`. `num_device_wires` can take on the value of `len(tape.wires)` when the measurement is broadcasted on all available wires, but the device does not specify a number of wires. **Benefits:** Improved jit support. **Possible Drawbacks:** Breaking changes always cause some draw backs. We don't know who may be relying on this method somewhere in the wild. But given we only use the method for the jax-jit interface, I think we are safe to make a breaking change here. **Related GitHub Issues:** [sc-65313] [sc-59327] Fixes #5813
Context:
When using jax jit and non-backprop, we need to know the exact shape of the result value in order to use a
pure_callback
for the device execution. The framework we use to determine this shape currently does not work for measurements on all available wires when the device does not specify a wire order.It also explicitly errors out with
qml.counts
, even whenall_outcomes=True
. But whenall_outcomes=True
, we know the exact resulting shape, so we can integrate it with jax-jit anyway.Description of the Change:
This PR makes a breaking change to
MeasurementProcess.shape
. Now it's call signature is(self, shots: Optional[int]=None, num_device_wires:int =0)
.num_device_wires
can take on the value oflen(tape.wires)
when the measurement is broadcasted on all available wires, but the device does not specify a number of wires.Benefits:
Improved jit support.
Possible Drawbacks:
Breaking changes always cause some draw backs. We don't know who may be relying on this method somewhere in the wild. But given we only use the method for the jax-jit interface, I think we are safe to make a breaking change here.
Related GitHub Issues:
Fixes #5813
Related Shortcut Stories:
[sc-65313] [sc-59327]