[Pallas] Introduce make_kernel_from_pallas #6713
Conversation
@@ -56,3 +62,50 @@ def _extract_backend_config(
    if op.name == "stablehlo.custom_call":
      return op.backend_config.value
  return None


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
@qihqi do we already have such a conversion somewhere in torchxla2?
It's generated by copilot, lol
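For reference, such a conversion usually amounts to a small lookup table. A minimal sketch of what it could look like; the mapping below is illustrative and not necessarily the exact one in this PR:

import torch
import jax.numpy as jnp

# Illustrative mapping only; extend as needed.
_TORCH_TO_JAX_DTYPE = {
    torch.float32: jnp.float32,
    torch.float64: jnp.float64,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
    torch.int32: jnp.int32,
    torch.int64: jnp.int64,
    torch.bool: jnp.bool_,
}


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  # Fail loudly on unsupported dtypes instead of silently mis-mapping them.
  if dtype not in _TORCH_TO_JAX_DTYPE:
    raise ValueError(f"Unsupported torch dtype: {dtype}")
  return _TORCH_TO_JAX_DTYPE[dtype]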
(x.shape, x.dtype))

dtypes = [torch.float32, torch.float]  # TODO: torch.float64, torch.bfloat16, torch.float16 don't work.
Why won't bf16 work?
Mosaic complains. Need to dig more into it.
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
Seems like this is repeated across multiple tests, maybe just move it to the top?
There is a compatibility issue where JAX will try to lock the TPU devices if we import it before any PT/XLA computations... I will need to resolve that...
Do you need this PR in 2.3?
Yea, will also need a couple more for the TODOs.
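To illustrate the workaround being discussed, a sketch (assuming the constraint is only that JAX must not initialize the TPU before PyTorch/XLA does): keep the JAX imports inside the test body so they run after a PT/XLA computation has already claimed the device. The test name and body below are illustrative, not the exact test in this PR.

import unittest

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class PallasTest(unittest.TestCase):

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_add_kernel(self):
    # Touch the PT/XLA device first ...
    device = xm.xla_device()
    x = torch.ones(8, device=device)

    # ... and only then import JAX, so it does not try to grab the TPU
    # before PyTorch/XLA has claimed it.
    import jax
    import jax.numpy as jnp
    import jax._src.pallas.mosaic.pallas_call_registration  # as in the tests above
    ...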
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# TODO: This test cannot be run individually, let's fix it.
def test_tpu_custom_call_pallas_wrap_add_payload(self):
  import jax
I am concerned that JAX-based tests will cause failures due to libtpu version inconsistencies, and in turn CI hiccups. How do we resolve this concern?
That's resolved in the last PR: #6696
def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
  # TODO: Maybe we can cache the payload for the same input.
The payload may change if the input is dynamic. We need to confirm this with the Pallas folks.
Right, the cache itself should deal with the dynamism.
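A rough sketch of the caching idea being discussed (purely illustrative; trace_to_payload below is a hypothetical stand-in for whatever step lowers the pallas_call into a Mosaic payload):

# Sketch only: key the cache on everything the payload can depend on, so that
# dynamic inputs (new shapes/dtypes) simply miss the cache and get re-traced.
_payload_cache = {}


def _cache_key(kernel, tensors):
  return (kernel, tuple((tuple(t.shape), t.dtype) for t in tensors))


def _get_payload(kernel, tensors):
  key = _cache_key(kernel, tensors)
  if key not in _payload_cache:
    # Hypothetical helper: traces/lowers the pallas_call and returns the payload.
    _payload_cache[key] = trace_to_payload(kernel, tensors)
  return _payload_cache[key]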
Can I get some reviews?
I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use). Approving to unblock.
Yea, for sure. Let me follow up on that.
Summary:
This pull request introduces the make_kernel_from_pallas API, the top-level API for interacting with the Pallas integration. It takes a pallas_call wrapper and then makes it a custom PyTorch op.
Test Plan:
python test/test_pallas.py
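For reference, a usage sketch of the new API. The import path and the add kernel below are assumptions based on this PR's tests, not copied verbatim from the diff:

import torch
import torch_xla.core.xla_model as xm

import jax
from jax.experimental import pallas as pl

# Import path is assumed; adjust to wherever this PR places the API.
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas


def add_vectors_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)


# output_shape_dtype_fn tells the wrapper what output to allocate for the
# given torch inputs; here the output matches the first input.
pt_add = make_kernel_from_pallas(add_vectors, lambda x, y: (x.shape, x.dtype))

device = xm.xla_device()
x = torch.arange(8, dtype=torch.float32, device=device)
y = torch.arange(8, dtype=torch.float32, device=device)
print(pt_add(x, y))  # expected: x + y, i.e. 0, 2, 4, ...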