
Support JAX's distributed arrays #252

Closed
shoyer opened this issue Apr 21, 2023 · 4 comments
Labels
enhancement (New feature or request)

Comments


shoyer commented Apr 21, 2023

Describe the bug

Using ... in einops.rearrange introduces extraneous reshape operations, where multiple dimensions are flattened into 1D and then reshaped back.

This is typically fine, but can be problematic in (at least) two contexts:

  1. When using JAX's distributed arrays, the reshape operation removes axis identity. This makes it harder (impossible?) for XLA to preserve the sharding of distributed arrays.
  2. With non-C-contiguous arrays, np.reshape can entail a memory copy (see the quick check after this list).
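
To make point 2 concrete, here is a quick check (a minimal illustration of NumPy's documented behavior, not part of the original report): flattening a Fortran-ordered array in the default C order cannot be expressed as a view, so np.reshape copies.

import numpy as np

x_c = np.zeros((2, 3, 4))             # C-contiguous
x_f = np.zeros((2, 3, 4), order='F')  # Fortran-contiguous

# Flattening a C-contiguous array yields a view of the same memory;
# for the F-ordered array the same reshape must copy.
print(np.shares_memory(x_c, np.reshape(x_c, -1)))  # True (view)
print(np.shares_memory(x_f, np.reshape(x_f, -1)))  # False (copy)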

Reproduction steps

Using JAX:

import einops
import jax
import numpy as np

x = np.zeros((2, 3, 4))
jax.make_jaxpr(lambda x: einops.rearrange(x, '... -> ...'))(x)
# { lambda ; a:f32[2,3,4]. let
#    b:f32[24] = reshape[dimensions=None new_sizes=(24,)] a
#    c:f32[2,3,4] = reshape[dimensions=None new_sizes=(2, 3, 4)] b
#  in (c,) }

Using NumPy:

import einops
import numpy as np

x = np.zeros((2, 3, 4), order='F')
y = einops.rearrange(x, '... -> ...')
x[...] = 1
print(y)  # all zeros

Expected behavior

This operation should just be the identity, preserving the original array shapes and memory views:

  • For the JAX example, the JAXpr would be just { lambda ; a:f32[2,3,4]. let in (a,) }.
  • For the NumPy example, the array y would be all ones after modifying x.

This is what happens currently if you use explicitly named dimensions:

import einops
import jax
import numpy as np

x = np.zeros((2, 3, 4))
jax.make_jaxpr(lambda x: einops.rearrange(x, 'x y z -> x y z'))(x)
# { lambda ; a:f32[2,3,4]. let  in (a,) }

And with NumPy:

import einops
import numpy as np

x = np.zeros((2, 3, 4), order='F')
y = einops.rearrange(x, 'x y z -> x y z')
x[...] = 1
print(y)  # all ones

More generally, '...' should generate code equivalent to fully explicit dimension names.

Your platform

einops 0.6.1

shoyer added the bug (Something isn't working) label Apr 21, 2023
arogozhnikov (Owner) commented:

Nice catch. I did not have in mind any special axes (like the sharding axis in your example) when designing this.

Pre-composing and post-decomposing the ellipsis dramatically simplifies the intermediate logic, so I will need to redesign that part.

In the meantime, the recommendation is to not include the sharding axis (usually batch) in the ellipsis.
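
A sketch of that workaround (hypothetical shapes; assumes the leading axis is the sharded batch axis):

import einops
import numpy as np

x = np.zeros((8, 3, 4))  # leading axis: the sharded batch axis

# '... -> ...' folds every axis, batch included, into the intermediate
# reshape. Naming the batch axis preserves its identity; only the
# remaining ellipsis axes are composed and decomposed.
y = einops.rearrange(x, 'b ... -> b ...')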

arogozhnikov (Owner) commented:

@shoyer can you try this branch?

https://github.com/arogozhnikov/einops/tree/preserve-axis-identity

I've implemented logic there that avoids pre-composing the ellipsis by precomputing a "recipe" for every input dimensionality.
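
Roughly, the idea is something like this simplified sketch (not the actual implementation; grouped axes like '(h w)' are ignored for brevity):

from functools import lru_cache

# Hypothetical sketch: expand '...' into explicit per-axis names once
# the input's ndim is known, caching one recipe per (pattern, ndim).
@lru_cache(maxsize=None)
def expand_ellipsis(pattern: str, ndim: int) -> str:
    left, right = pattern.split('->')
    n_named = len(left.replace('...', '').split())
    extra = ' '.join(f'_e{i}' for i in range(ndim - n_named))
    return left.replace('...', extra) + '->' + right.replace('...', extra)

print(expand_ellipsis('b ... -> b ...', 3))  # b _e0 _e1 -> b _e0 _e1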


shoyer commented May 3, 2023

I can confirm that the issue reported in my original post seems to be resolved! Both test cases now demonstrate the expected behavior.

I would consider adding at least the NumPy example as a regression test:

import einops
import numpy as np

x = np.zeros((2, 3, 4), order='F')
y = einops.rearrange(x, '... -> ...')
x[...] = 1
np.testing.assert_array_equal(x, y)

arogozhnikov added the enhancement (New feature or request) label and removed the bug (Something isn't working) label Jul 7, 2023
arogozhnikov changed the title from "[BUG] Ellipsis introduces extraneous reshapes" to "Support Jax's distributed arrays" Jul 8, 2023
arogozhnikov changed the title from "Support Jax's distributed arrays" to "Support JAX's distributed arrays" Jul 8, 2023

shoyer commented Oct 2, 2023

This is fixed now, thank you!

shoyer closed this as completed Oct 2, 2023