
Fix numba AdvancedIncSubtensor1 with broadcasted values #757

Merged: 1 commit merged into pymc-devs:main on May 24, 2024

Conversation

ricardoV94 (Member)

Description

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

codecov bot commented May 8, 2024

Codecov Report

Attention: Patch coverage is 89.74359% with 4 lines in your changes missing coverage. Please review.

Project coverage is 80.85%. Comparing base (1e96b89) to head (85b0c5d).
Report is 209 commits behind head on main.

Files                                    Patch %   Lines
pytensor/link/numba/dispatch/basic.py    89.74%    2 Missing and 2 partials ⚠️
Additional details and impacted files


@@           Coverage Diff           @@
##             main     #757   +/-   ##
=======================================
  Coverage   80.85%   80.85%           
=======================================
  Files         162      162           
  Lines       47019    47045   +26     
  Branches    11502    11514   +12     
=======================================
+ Hits        38017    38039   +22     
- Misses       6748     6753    +5     
+ Partials     2254     2253    -1     
Files                                    Coverage Δ
pytensor/link/numba/dispatch/basic.py    85.68% <89.74%> (-0.06%) ⬇️

... and 3 files with indirect coverage changes

ricardoV94 added the bug (Something isn't working) label on May 9, 2024
aseyboldt (Member)

I think we might need a workaround for this issue: numba/numba#9573
If the value on the right-hand side of the update is a rank-zero array (i.e. not a scalar) that gets broadcast, I think the current code will break. We should at least test for that. Maybe by manually calling funcify and, because of the recent no_cpython_wrapper addition, adding a manual wrapper function around that?
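
For reference, a quick illustration (not taken from the PR) of the distinction drawn here between a rank-zero array and a plain scalar:

import numpy as np

rank0 = np.array(3.0)   # 0d ndarray: ndim == 0, shape == ()
scalar = 3.0            # plain Python float, not an array

print(type(rank0), rank0.ndim, rank0.shape)  # <class 'numpy.ndarray'> 0 ()
print(type(scalar))                          # <class 'float'>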

ricardoV94 (Member, Author) commented May 24, 2024

> I think we might need a workaround for this issue: numba/numba#9573 If the value on the right hand side of the update is a rank zero array (ie not a scalar) that gets broadcasted, I think the current code will break. We should at least test for that I think. Maybe by manually calling funcify and because of the recent no_cpython_wrapper addition, a manual wrapper function around that?

The last test case I added in this PR is a tensor scalar on the rhs:

(
    pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
    pt.as_tensor(rng.poisson(size=())),  # <- Scalar y
    ([2, 0],),
),

This is the fgraph of the last check:

AdvancedIncSubtensor1{inplace,inc} [id A] <Tensor3(int64, shape=(3, 4, 5))> 0
 ├─ <Tensor3(int64, shape=(3, 4, 5))> [id B] <Tensor3(int64, shape=(3, 4, 5))>
 ├─ <Scalar(int64, shape=())> [id C] <Scalar(int64, shape=())>
 └─ [2 0] [id D] <Vector(int64, shape=(2,))>

I don't think we allow pytensor scalars in AdvancedSetSubtensor. They always get converted to TensorVariables:

import pytensor.tensor as pt
import pytensor.scalar as ps

x = pt.vector("x")
y = ps.float64("y")
x[[0, 1]].set(y).dprint(print_type=True)

# AdvancedSetSubtensor [id A] <Vector(float64, shape=(?,))>
#  ├─ x [id B] <Vector(float64, shape=(?,))>
#  ├─ TensorFromScalar [id C] <Scalar(float64, shape=())>
#  │  └─ y [id D] <float64>
#  └─ [0 1] [id E] <Vector(int64, shape=(2,))>

ricardoV94 force-pushed the advanced_subtensor1_numba branch 2 times, most recently from 6b1e8ae to 512c09d, on May 24, 2024 at 14:00

ricardoV94 (Member, Author)

Okay, finally got it: the bug occurs when the lhs is a vector and the rhs is a 0d array. It's fine if the lhs is a higher-rank array, which is what I was testing before.
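
A rough sketch of that failing pattern (an illustration based on the description above, not code taken from the PR): assigning a 0d array into an element of a 1d array inside an njit function, which appears to be the case numba/numba#9573 covers.

import numpy as np
from numba import njit

@njit(boundscheck=True)
def set_rows(x, val, idxs):
    # With a 1d x, each x[idx] is a single element, so assigning the 0d
    # array val here is the problematic case; with a higher-rank x the
    # same assignment broadcasts val over the row and works.
    for idx in idxs:
        x[idx] = val
    return x

set_rows(np.arange(5.0), np.asarray(3.0), np.array([0, 2]))  # expected to hit the numba issue

One possible workaround of this shape is to extract a scalar from the 0d array (for example with val.item()) before the loop.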

ricardoV94 (Member, Author)

@aseyboldt I added a patch for the bug you found and a test that would fail without the patch. Let me know what you think

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
    if val.ndim == x.ndim:
        core_val = val[0]
ricardoV94 (Member, Author)

I have to reassign to a new named variable, otherwise typing fails. Is that expected from Numba?

aseyboldt (Member)

Sometimes numba seems to require this when the type changes, but I don't know the exact rules either.
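
A generic illustration of this behaviour (an assumption for illustration, not the PR's code): numba gives each variable a single type across all control-flow paths, so rebinding a name to a value of a different type can fail type inference, while binding a fresh name works.

import numpy as np
from numba import njit

@njit
def rebind_same_name(arr):
    if arr.sum() > 0:
        arr = arr[0]   # arr would be a float64 here but a 1d array elsewhere
    return arr         # typing can fail: cannot unify array and scalar

@njit
def bind_new_name(arr):
    core = arr[0]      # fresh name with its own scalar type compiles fine
    return core

bind_new_name(np.arange(3.0))  # works; calling rebind_same_name is expected to raise a TypingError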

ricardoV94 (Member, Author)

Opened an issue; perhaps it's documented behavior, but it seemed strange: numba/numba#9587

aseyboldt (Member) left a comment

Looks good

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
    if val.ndim == x.ndim:
        core_val = val[0]
ricardoV94 (Member, Author) commented May 24, 2024

This indexing is always safe because x must be at least 1D (otherwise make_node raises)

ricardoV94 merged commit 5f374db into pymc-devs:main on May 24, 2024. 55 of 56 checks passed.
Labels: bug (Something isn't working), indexing, numba
Projects: None yet
Development: successfully merging this pull request may close these issues:
  No broadcasting support in numba AdvancedIncSubtensor1
2 participants