Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove conservative checks for supported Subtensors operations in JAX #849

Merged
merged 1 commit into from
Jun 28, 2024

Conversation

ricardoV94
Copy link
Member

Description

The check was failing incorrectly for cases that are supported such as constant Boolean arrays. Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
out_fg = FunctionGraph([x_pt], [out_pt])
Copy link
Member Author

@ricardoV94 ricardoV94 Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise I suspect we were not really testing jax subtensor dispatch as the index operation would be just constant_folded. This was not always the case, as we used to not run any rewrites in compare_jax_and_py. Now we do

compare_jax_and_py(out_fg, [x_np])

# Boolean indexing should work if indexes are constant
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5))]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the case that led me to open this PR. JAX is happy to do it but our check wouldn't let it compile

The check was failing incorrectly for cases that are supported such as constant Boolean arrays.
Besides that, user may dispatch without necessarily jitting the graph. There is no reason to fail eagerly.
Copy link

codecov bot commented Jun 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.88%. Comparing base (d3bd1f1) to head (2b99224).
Report is 158 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #849      +/-   ##
==========================================
- Coverage   80.89%   80.88%   -0.01%     
==========================================
  Files         169      169              
  Lines       46979    46966      -13     
  Branches    11478    11472       -6     
==========================================
- Hits        38002    37989      -13     
  Misses       6764     6764              
  Partials     2213     2213              
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/subtensor.py 86.95% <100.00%> (-2.88%) ⬇️

@ricardoV94 ricardoV94 changed the title Remove false positive checks for supported Subtensors operations in JAX Remove conservative checks for supported Subtensors operations in JAX Jun 25, 2024
@lucianopaz lucianopaz merged commit 684a929 into pymc-devs:main Jun 28, 2024
57 checks passed
@ricardoV94 ricardoV94 mentioned this pull request Jun 28, 2024
11 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants