Description
HLO
The above fails with
On the CPU it correctly returns `[2. 1.]`. The issue is specifically about the indices being `uint32`.

Note that this is not an issue when doing a usual access like `x[idx]`, because before it reaches `jax.lax.gather`, the indices are always converted to `int32`.
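The original reproducer is not included in this copy of the issue, so the following is only a minimal sketch of the pattern being described, under stated assumptions: the inputs `x` and `idx` are hypothetical values chosen so that a correct gather returns `[2. 1.]`, and the dimension numbers are one plausible way to pick one scalar per index. It calls `jax.lax.gather` directly with `uint32` indices (the reported case) and, as a possible workaround, casts them to `int32` first, mirroring what the issue says `x[idx]` does internally.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical inputs (not from the original report): a correct gather
# of x at positions [1, 0] yields [2. 1.].
x = jnp.array([1.0, 2.0], dtype=jnp.float32)
idx = jnp.array([[1], [0]], dtype=jnp.uint32)  # shape (2, 1), uint32 indices

# Pick one scalar element of x for each row of idx.
dnums = lax.GatherDimensionNumbers(
    offset_dims=(),             # no slice dimensions remain in the output
    collapsed_slice_dims=(0,),  # the size-1 slice along axis 0 is dropped
    start_index_map=(0,),       # each index addresses axis 0 of x
)


def gather_uint32(x, idx):
    # Passes uint32 indices straight to lax.gather (the case reported above).
    return lax.gather(x, idx, dimension_numbers=dnums, slice_sizes=(1,))


def gather_int32(x, idx):
    # Possible workaround: cast to int32 first, as x[idx] does before gather.
    return lax.gather(x, idx.astype(jnp.int32), dimension_numbers=dnums,
                      slice_sizes=(1,))


print(jax.default_backend())   # which backend ran the gathers
print(gather_uint32(x, idx))
print(gather_int32(x, idx))
print(x[idx[:, 0]])            # the usual indexing path
```

On a backend that handles `uint32` indices correctly, all three printed results should agree; comparing `gather_uint32` against `gather_int32` on the Metal backend isolates the index dtype as the variable.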
System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7