
[Pallas] Introduce jax_import_guard #6747

Merged 7 commits on Mar 20, 2024
Changes from all commits
17 changes: 7 additions & 10 deletions test/test_pallas.py
@@ -8,6 +8,13 @@
import torch_xla
from torch_xla import runtime as xr

if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


class PallasTest(unittest.TestCase):

@@ -111,12 +118,8 @@ def add_one_pallas(output, inputs, payload):

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_extract_add_payload(self):
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration

from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
@@ -136,13 +139,7 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
self.assertIn("custom_call_config", payload)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# TODO: This test cannot be run individually; let's fix it.
def test_tpu_custom_call_pallas_wrap_add_payload(self):
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration

from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
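
The guard pattern above is usable outside the test suite as well. Below is a minimal sketch of how a caller might defer its own JAX imports behind jax_import_guard so that PyTorch/XLA initializes its computation client (and claims the TPU) before JAX does; the trivial add kernel and the shapes are illustrative only and not part of this PR.

# Sketch: defer JAX imports behind jax_import_guard (assumes a TPU runtime).
from torch_xla import runtime as xr

if xr.device_type() == 'TPU':
  # Initialize the PyTorch/XLA computation client before JAX can lock the TPU.
  from torch_xla.experimental.custom_kernel import jax_import_guard
  jax_import_guard()
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl

  def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Element-wise add, written as a Pallas kernel.
    o_ref[...] = x_ref[...] + y_ref[...]

  @jax.jit
  def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

  # JAX-side smoke test of the kernel.
  print(add_vectors(jnp.arange(8), jnp.arange(8)))
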
58 changes: 33 additions & 25 deletions torch_xla/experimental/custom_kernel.py
@@ -1,12 +1,8 @@
import functools
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
import torch
import torch_xla
import torch_xla.core.xla_model as xm

from jax.experimental import pallas as pl
from typing import List, Callable
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB
@@ -64,30 +60,42 @@ def _extract_backend_config(
return None


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
return jnp.float32
elif dtype == torch.float64:
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
elif dtype == torch.bfloat16:
return jnp.bfloat16
elif dtype == torch.int32:
return jnp.int32
elif dtype == torch.int64:
return jnp.int64
elif dtype == torch.int16:
return jnp.int16
elif dtype == torch.int8:
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def jax_import_guard():
# Somehow, we need to grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operation will hang.
torch_xla._XLAC._init_computation_client()


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
# Import JAX within the function so that we don't need to call jax_import_guard()
# in the global scope, which could cause problems for xmp.spawn.
jax_import_guard()
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
from jax.experimental import pallas as pl

def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
if dtype == torch.float32:
return jnp.float32
elif dtype == torch.float64:
return jnp.float64
elif dtype == torch.float16:
return jnp.float16
elif dtype == torch.bfloat16:
return jnp.bfloat16
elif dtype == torch.int32:
return jnp.int32
elif dtype == torch.int64:
return jnp.int64
elif dtype == torch.int16:
return jnp.int16
elif dtype == torch.int8:
return jnp.int8
elif dtype == torch.uint8:
return jnp.uint8
else:
raise ValueError(f"Unsupported dtype: {dtype}")

# TODO: Maybe we can cache the payload for the same input.
def wrapped_kernel(kernel: Callable, output_shape_dtype_fn: Callable, *args):
jax_args = []
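
For completeness, here is a hedged sketch of how the refactored make_kernel_from_pallas might be called from PyTorch/XLA code. It assumes a TPU runtime, that output_shape_dtype_fn returns a (shape, dtype) pair for the single output, and that the returned callable accepts XLA tensors; see test_pallas.py in this PR for the authoritative usage.

# Sketch: wrap a Pallas kernel for use on XLA tensors (assumptions noted above).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import (jax_import_guard,
                                                  make_kernel_from_pallas)

jax_import_guard()  # claim the TPU before importing JAX
import jax
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)


# Assumed contract: the second argument maps the input tensors to the output's
# (shape, dtype); the wrapper lowers the kernel and dispatches it as a TPU
# custom call with the extracted payload.
pt_add = make_kernel_from_pallas(add_vectors, lambda x, y: (x.shape, x.dtype))

device = xm.xla_device()
x = torch.arange(8, dtype=torch.int32, device=device)
y = torch.arange(8, dtype=torch.int32, device=device)
print(pt_add(x, y))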