Support JAX's distributed arrays #252
Comments
Nice catch. I did not have any special axes in mind (like the sharding axis in your example) when designing this. Pre-composing and post-decomposing the ellipsis dramatically simplifies the intermediate logic, so I will need to redesign the logic there. In the meantime, the recommendation is to not include the sharding axis (usually batch) in the ellipsis.
@shoyer can you try this branch? https://github.com/arogozhnikov/einops/tree/preserve-axis-identity There I've implemented the logic without pre-composing the ellipsis, by precomputing a "recipe" for every input dimensionality.
I can confirm that the issue reported in my original post seems to be resolved! Both test cases now demonstrate the expected behavior. I would consider adding at least the NumPy example as a regression test:
import numpy as np
import einops

x = np.zeros((2, 3, 4), order='F')
y = einops.rearrange(x, '... -> ...')
x[...] = 1
np.testing.assert_array_equal(x, y)
This is fixed now, thank you!
Describe the bug
Using `...` in `einops.rearrange` introduces extraneous reshape operations, where multiple dimensions are flattened into 1D and then reshaped back.
This is typically fine, but can be problematic in (at least) two contexts:
`np.reshape` can entail a memory copy.
Reproduction steps
Using JAX:
Using NumPy:
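As a rough NumPy-only sketch of the memory-copy concern, the following emulates the flatten-then-restore round trip by hand. It does not call einops, the variable names are illustrative, and the two reshapes are only a stand-in for the extra operations described above. For a Fortran-ordered array, the round trip yields a copy rather than a view:

```python
import numpy as np

# Fortran-ordered input, as in the regression test suggested in the comments.
x = np.zeros((2, 3, 4), order="F")

# Hand-written stand-in for pre-composing the ellipsis axes into one axis
# and decomposing them back; this is not einops internals.
flat = x.reshape(-1)              # C-order flatten of an F-contiguous array must copy
restored = flat.reshape(x.shape)  # a view of `flat`, not of `x`

# A true identity would return a view that tracks later writes to x.
x[...] = 1
shares = np.shares_memory(x, restored)
print(shares)          # False
print(restored.max())  # 0.0
```

Because `restored` no longer shares memory with `x`, writes to `x` are not visible through it, which is the behavior the report objects to.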
Expected behavior
This operation should just be the identity, preserving the original array shapes and memory views:
`{ lambda ; a:f32[2,3,4]. let in (a,) }`.
`y` would be all ones after modifying `x`.
This is what happens currently if you use explicitly named dimensions:
More generally, `...` should generate code equivalent to fully explicit dimension names.
Your platform
einops 0.6.1
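One way to picture the expectation that `...` behaves like fully explicit dimension names is to expand the ellipsis into one name per axis once the input rank is known, caching the result per (pattern, rank) pair. The sketch below is a simplified illustration under stated assumptions: `expand_ellipsis` and the generated `ax0`, `ax1`, ... names are hypothetical, parenthesised groups and name collisions are not handled, and this is not how einops itself is implemented.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def expand_ellipsis(pattern: str, ndim: int) -> str:
    """Rewrite '...' as explicit axis names for an input of rank `ndim`.

    Hypothetical helper for illustration only; not part of the einops API.
    Supports flat patterns such as 'b ... c -> b c ...' (no parentheses).
    """
    left, right = (side.split() for side in pattern.split("->"))
    named = sum(1 for token in left if token != "...")
    n_extra = ndim - named
    if n_extra < 0 or ("..." not in left and n_extra != 0):
        raise ValueError(f"pattern {pattern!r} does not match rank {ndim}")

    # One generated name per axis covered by the ellipsis.
    extra = [f"ax{i}" for i in range(n_extra)]

    def expand(tokens):
        out = []
        for token in tokens:
            out.extend(extra if token == "..." else [token])
        return " ".join(out)

    return f"{expand(left)} -> {expand(right)}"


identity = expand_ellipsis("... -> ...", 3)
moved = expand_ellipsis("b ... c -> b c ...", 4)
print(identity)  # ax0 ax1 ax2 -> ax0 ax1 ax2
print(moved)     # b ax0 ax1 c -> b c ax0 ax1
```

With every axis named individually, the identity pattern maps each axis to itself, so no axes need to be merged and split again.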