jax-metal: gather fails with unsigned indices #21547

Closed
jonatanklosko opened this issue May 31, 2024 · 1 comment
Labels: Apple GPU (Metal) plugin, bug (Something isn't working)

Comments

@jonatanklosko

Description

import jax
import jax.numpy as jnp

def f(x, idx):
  dnums = jax.lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
  return jax.lax.gather(x, idx, dimension_numbers=dnums, slice_sizes=(1,))

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([[1], [0]], dtype=jnp.uint32)

# Print lowered HLO
print(jax.jit(f).lower(x, idx).as_text())
print(jax.jit(f)(x, idx))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<2x1xui32> {mhlo.layout_mode = "default"}) -> (tensor<2xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<3xf32>, tensor<2x1xui32>) -> tensor<2xf32>
    return %0 : tensor<2xf32>
  }
}

On the Metal backend, the above fails with:

LLVM ERROR: Failed to infer result type(s).

On the CPU it correctly returns [2. 1.]. The failure is specific to the indices being uint32.
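For reference, the gather in the repro (collapsed_slice_dims=(0,), start_index_map=(0,), slice_sizes=(1,)) picks one element of x per row of idx, so out[i] == x[idx[i, 0]]. A minimal NumPy sketch of that expected result, reusing the inputs from the repro:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
idx = np.array([[1], [0]], dtype=np.uint32)

# Each row of idx holds one start index into axis 0 (start_index_map=(0,)).
# The size-1 slice along axis 0 is collapsed away (collapsed_slice_dims=(0,)),
# so the gather reduces to plain element selection: out[i] = x[idx[i, 0]].
ref = x[idx[:, 0]]
print(ref)  # [2. 1.]
```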

Note that this is not an issue with ordinary indexing such as x[idx], because the indices are always converted to int32 before they reach jax.lax.gather.
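A possible workaround on affected jax-metal versions, sketched under the assumption (suggested by the report but untested on Metal) that only the unsigned index dtype triggers the failure: cast the indices to int32 before calling jax.lax.gather, mirroring what x[idx] does. The helper name gather_rows is hypothetical, and the cast assumes every index fits in int32 (uint32 values of 2**31 or more are not representable).

```python
import jax
import jax.numpy as jnp

def gather_rows(x, idx):
    # Hypothetical helper: cast uint32 indices to int32 first, as ordinary
    # x[idx] indexing does. Assumes all index values are below 2**31.
    idx = idx.astype(jnp.int32)
    dnums = jax.lax.GatherDimensionNumbers(
        offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
    return jax.lax.gather(x, idx, dimension_numbers=dnums, slice_sizes=(1,))

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([[1], [0]], dtype=jnp.uint32)
out = jax.jit(gather_rows)(x, idx)
print(out)  # [2. 1.] on CPU
```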

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

@jonatanklosko (Author)

Works in 0.1.0, thanks :)
