Fix numba AdvancedIncSubtensor1 with broadcasted values #757
Conversation
Force-pushed from 28203f6 to 8ac578a
Force-pushed from 8ac578a to 1e5a1a1
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@ Coverage Diff @@
##             main     #757   +/-   ##
=======================================
  Coverage   80.85%   80.85%
=======================================
  Files         162      162
  Lines       47019    47045     +26
  Branches    11502    11514     +12
=======================================
+ Hits        38017    38039     +22
- Misses       6748     6753      +5
+ Partials     2254     2253      -1
```
I think we might need a workaround for this issue: numba/numba#9573
Force-pushed from 1e5a1a1 to 82527a4
The last test case I added in this PR is a tensor scalar on the rhs:

```python
(
    pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
    pt.as_tensor(rng.poisson(size=())),  # <- Scalar y
    ([2, 0],),
),
```

This is the fgraph of the last check:

```
AdvancedIncSubtensor1{inplace,inc} [id A] <Tensor3(int64, shape=(3, 4, 5))> 0
 ├─ <Tensor3(int64, shape=(3, 4, 5))> [id B] <Tensor3(int64, shape=(3, 4, 5))>
 ├─ <Scalar(int64, shape=())> [id C] <Scalar(int64, shape=())>
 └─ [2 0] [id D] <Vector(int64, shape=(2,))>
```

I don't think we allow pytensor scalars in `AdvancedSetSubtensor`. They always get converted to `TensorVariable`s:

```python
import pytensor.tensor as pt
import pytensor.scalar as ps

x = pt.vector("x")
y = ps.float64("y")
x[[0, 1]].set(y).dprint(print_type=True)
# AdvancedSetSubtensor [id A] <Vector(float64, shape=(?,))>
#  ├─ x [id B] <Vector(float64, shape=(?,))>
#  ├─ TensorFromScalar [id C] <Scalar(float64, shape=())>
#  │  └─ y [id D] <float64>
#  └─ [0 1] [id E] <Vector(int64, shape=(2,))>
```
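The semantics that test case exercises can be sketched in plain NumPy (this is an illustration, not the Numba kernel itself): an inplace `inc` with advanced integer indices behaves like `np.add.at`, which broadcasts a scalar rhs against each selected subtensor.

```python
import numpy as np

x = np.arange(3 * 4 * 5).reshape((3, 4, 5))
y = np.int64(7)   # 0-d (scalar) value on the rhs
idxs = [2, 0]

# Reference result: increment each selected subtensor by the scalar.
expected = x.copy()
expected[2] += y
expected[0] += y

# np.add.at broadcasts y against x[2] and x[0].
out = x.copy()
np.add.at(out, idxs, y)
assert np.array_equal(out, expected)
```

Rows not listed in `idxs` are left untouched, which is the behavior the Numba implementation has to reproduce.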
Force-pushed from 6b1e8ae to 512c09d
Okay, finally got it: the bug occurs when the lhs is a vector and the rhs a 0d array. It's fine if the lhs is a higher-rank array, which is what I was testing before.
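The failing shape combination described here can be reproduced at the NumPy level (a hypothetical repro, not the original test): a 1-d lhs with a 0-d array on the rhs, where each indexed element is itself a scalar.

```python
import numpy as np

x = np.zeros(5)            # lhs is a vector
val = np.asarray(3.0)      # rhs is a 0-d array: val.ndim == 0
idxs = np.array([1, 3])

# Plain NumPy handles this combination fine; the compiled kernel is the
# part that tripped over a 0-d value paired with a 1-d target.
np.add.at(x, idxs, val)
assert x[1] == 3.0 and x[3] == 3.0
```

With a higher-rank lhs, `x[i]` is a subarray and the value broadcasts against it, which is why only the vector case surfaced the bug.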
Force-pushed from 512c09d to 85b0c5d
@aseyboldt I added a patch for the bug you found and a test that would fail without the patch. Let me know what you think.
```python
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
    if val.ndim == x.ndim:
        core_val = val[0]
```
I have to reassign to a new named variable, otherwise typing fails. Is that expected from Numba?
Sometimes numba seems to require this when the type changes, but I don't know the exact rules either.
Opened an issue; perhaps it's documented behavior, but it seemed strange: numba/numba#9587
Looks good
```python
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
    if val.ndim == x.ndim:
        core_val = val[0]
```
This indexing is always safe because `x` must be at least 1D (otherwise `make_node` raises).
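A rough NumPy-level sketch of the logic the patched kernel follows (the function name and loop are illustrative, not the actual Numba code): when `val` has the same rank as `x`, its leading axis is a broadcasted dimension of length 1, so the core value is `val[0]`; otherwise `val` itself is the core value.

```python
import numpy as np

def advancedincsubtensor1_inplace_sketch(x, val, idxs):
    # If val was broadcast up to x's rank, peel off the length-1 leading
    # axis; this is always safe because x (and hence val) is at least 1-d.
    if val.ndim == x.ndim:
        core_val = val[0]
    else:
        core_val = val
    for idx in idxs:
        x[idx] += core_val
    return x

x = np.zeros((3, 2))
val = np.ones((1, 2))   # same rank as x, broadcasted leading dim of length 1
out = advancedincsubtensor1_inplace_sketch(x, val, [0, 2])
assert np.array_equal(out[0], [1.0, 1.0])
assert np.array_equal(out[2], [1.0, 1.0])
```

Reassigning to the new name `core_val` (rather than rebinding `val`) mirrors the workaround discussed above for Numba's typing.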