>>> import numpy as np
>>> import jax.numpy as jnp
>>> np.tile(np.array([0, 1, 2]), (1, 1, 2))
array([[[0, 1, 2, 0, 1, 2]]])
>>> jnp.tile(jnp.array([0, 1, 2]), (1, 1, 2))
DeviceArray([[[0, 1, 2, 0, 1, 2]]], dtype=int64)
>>> np.tile(np.array([0, 1, 2]), (1, 0, 2))
array([], shape=(1, 0, 6), dtype=int64)
>>> jnp.tile(jnp.array([0, 1, 2]), (1, 0, 2))
...
ValueError: Need at least one array to concatenate
A local workaround is to wrap jax.numpy.tile like so:
if 0 in repeats:
    return jnp.array([]).reshape(np.array(tensor_in.shape) * np.array(repeats))
return jnp.tile(tensor_in, repeats)
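The workaround above is a fragment from inside a function, and it multiplies the input shape by repeats elementwise. That only matches np.tile when both have the same length: for an input of shape (3,) and repeats (1, 0, 2), NumPy broadcasting yields (3, 0, 6) rather than (1, 0, 6). It also builds the result from jnp.array([]), which has a floating-point dtype. Below is a minimal, self-contained sketch of the same idea that pads the shorter of the two with leading ones and keeps the input dtype. The names safe_tile and tiled_shape are hypothetical and not part of JAX; this is one possible wrapper, not the fix that was merged.

```python
import numpy as np
import jax.numpy as jnp


def tiled_shape(shape, reps):
    """Return the shape np.tile produces for an input of `shape` and `reps`."""
    # np.tile accepts a bare integer as well as a sequence of repetitions.
    reps = (reps,) if np.ndim(reps) == 0 else tuple(reps)
    # Left-pad the shorter tuple with ones so both have the same length.
    ndim = max(len(shape), len(reps))
    shape = (1,) * (ndim - len(shape)) + tuple(shape)
    reps = (1,) * (ndim - len(reps)) + reps
    return tuple(int(s) * int(r) for s, r in zip(shape, reps))


def safe_tile(tensor_in, repeats):
    """Hypothetical wrapper: like jnp.tile, but short-circuits empty results."""
    tensor_in = jnp.asarray(tensor_in)
    out_shape = tiled_shape(tensor_in.shape, repeats)
    if 0 in out_shape:
        # The result has no elements, so build it directly with the input dtype.
        return jnp.zeros(out_shape, dtype=tensor_in.dtype)
    return jnp.tile(tensor_in, repeats)


if __name__ == "__main__":
    x = jnp.array([0, 1, 2])
    print(safe_tile(x, (1, 0, 2)).shape)
    print(safe_tile(x, (1, 1, 2)))
```

Computing the output shape up front keeps the zero-size case independent of how jnp.tile is implemented, so the wrapper behaves the same on JAX versions with and without the upstream fix.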
Fix jnp.tile for cases with zero reps (fixes jax-ml#3919)
5baf954
Fix jnp.tile for cases with zero reps (fixes #3919) (#3922)
0ec1e25
Wow that was fast @jakevdp! Thanks so much! :)