-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove conservative checks for supported Subtensors operations in JAX #849
Conversation
out_pt = x_pt[1, 2, 0] | ||
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) | ||
out_fg = FunctionGraph([], [out_pt]) | ||
compare_jax_and_py(out_fg, []) | ||
out_fg = FunctionGraph([x_pt], [out_pt]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise I suspect we were not really testing jax subtensor dispatch as the index operation would be just constant_folded. This was not always the case, as we used to not run any rewrites in compare_jax_and_py
. Now we do
compare_jax_and_py(out_fg, [x_np]) | ||
|
||
# Boolean indexing should work if indexes are constant | ||
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5))] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was the case that led me to open this PR. JAX is happy to do it but our check wouldn't let it compile
c27898a
to
6fc8f7a
Compare
The check was failing incorrectly for cases that are supported such as constant Boolean arrays. Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.
6fc8f7a
to
2b99224
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #849 +/- ##
==========================================
- Coverage 80.89% 80.88% -0.01%
==========================================
Files 169 169
Lines 46979 46966 -13
Branches 11478 11472 -6
==========================================
- Hits 38002 37989 -13
Misses 6764 6764
Partials 2213 2213
|
Description
The check was failing incorrectly for cases that are supported such as constant Boolean arrays. Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.
Related Issue
Checklist
Type of change