
[Pallas] Introduce make_kernel_from_pallas #6713

Merged: 5 commits merged into master from alanwaketan/pallas_api on Mar 13, 2024

Conversation

alanwaketan (Collaborator)

Summary:
This pull request introduces the make_kernel_from_pallas API, the top-level API for interacting with the Pallas integration. It takes a pallas_call wrapper and turns it into a custom PyTorch op.

Test Plan:
python test/test_pallas.py
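
For illustration, here is a minimal usage sketch of the new API. This is not part of the PR's diff: the torch_xla import path, the add_vectors kernel, and the variable names are assumptions; the output shape/dtype lambda mirrors the one used in the tests.

import jax
from jax.experimental import pallas as pl
import torch
import torch_xla.core.xla_model as xm
# Assumed import path for the new API; adjust to wherever make_kernel_from_pallas is exposed.
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Pallas kernel body: element-wise add of the two input blocks.
  o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  # The pallas_call wrapper that make_kernel_from_pallas consumes.
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

# Turn the pallas_call wrapper into a custom PyTorch op; the second argument
# maps the torch inputs to the output's (shape, dtype).
pt_add = make_kernel_from_pallas(add_vectors, lambda x, y: (x.shape, x.dtype))

x = torch.arange(8, dtype=torch.float32, device=xm.xla_device())
y = torch.arange(8, dtype=torch.float32, device=xm.xla_device())
out = pt_add(x, y)  # runs the Pallas kernel as a PyTorch/XLA op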

alanwaketan requested review from JackCaoG and qihqi on March 11, 2024 19:04
@@ -56,3 +62,50 @@ def _extract_backend_config(
if op.name == "stablehlo.custom_call":
return op.backend_config.value
return None


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
Collaborator

@qihqi do we already have such a conversion somewhere in torchxla2?

alanwaketan (Collaborator Author)

It's generated by copilot, lol
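
For reference, the obvious shape of such a conversion is a small lookup table; the following is a sketch under that assumption, not necessarily the body that landed in this PR.

import torch
import jax.numpy as jnp

# Illustrative mapping; the dtypes the PR actually exercises are narrower
# (see the TODO about float64/bfloat16/float16 below).
_TORCH_TO_JAX_DTYPE = {
    torch.float32: jnp.float32,
    torch.float64: jnp.float64,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
    torch.int32: jnp.int32,
    torch.int64: jnp.int64,
}

def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  # Fail loudly on anything not mapped instead of guessing a default.
  try:
    return _TORCH_TO_JAX_DTYPE[dtype]
  except KeyError:
    raise ValueError(f"Unsupported torch dtype: {dtype}")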

(x.shape, x.dtype))

dtypes = [torch.float32, torch.float]  # TODO: torch.float64, torch.bfloat16, torch.float16 don't work.
Collaborator

why won't bf16 work?

alanwaketan (Collaborator Author)

Mosaic complains. Need to dig more into it.

Comment on lines +141 to +143
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
Collaborator

seems like this is repeated in multiple tests, maybe just move it to the top?

alanwaketan (Collaborator Author)

There is a compatibility issue where jax will try to lock the TPU devices if we import it before any PT/XLA computations... I will need to resolve that...
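
In other words, the jax imports stay local to each test so that a PT/XLA computation touches the device first; a rough sketch of that ordering, assuming a trivial warm-up computation:

import torch
import torch_xla.core.xla_model as xm

def run_pallas_test():
  # Touch PT/XLA first so torch_xla claims the TPU devices ...
  _ = torch.ones(1, device=xm.xla_device()) + 1
  # ... and only then import jax, so it does not try to lock them first.
  import jax
  import jax.numpy as jnp
  import jax._src.pallas.mosaic.pallas_call_registration
  return jnp.add(1, 1)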

JackCaoG (Collaborator) left a comment

Do you need this PR in 2.3?

alanwaketan (Collaborator Author)

Do you need this PR in 2.3?

Yea, will also need a couple for the TODOs.

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# TODO: This test cannot be run individually; let's fix it.
def test_tpu_custom_call_pallas_wrap_add_payload(self):
import jax
miladm (Collaborator) commented on Mar 12, 2024

I am concerned that JAX-based tests will cause failures due to libtpu version inconsistencies, and in turn CI hiccups. How do we resolve this concern?

alanwaketan (Collaborator Author)

That's resolved in the last PR: #6696



def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
# TODO: Maybe we can cache the payload for the same input.
Collaborator

The payload may change if the input is dynamic. We need to confirm this with pallas folks.

alanwaketan (Collaborator Author)

Right, the cache itself should deal with the dynamism.
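
One way to reconcile the caching TODO with the dynamism concern is to key the cache on the input shapes and dtypes, so a new shape simply re-traces into a new entry. A hedged sketch follows; PayloadCache and trace_fn are illustrative names, not code from this PR.

from typing import Callable, Dict, Tuple

import torch

def _payload_key(tensors: Tuple[torch.Tensor, ...]) -> Tuple:
  # Shapes and dtypes determine the lowered Mosaic payload for a static kernel.
  return tuple((tuple(t.shape), t.dtype) for t in tensors)

class PayloadCache:
  """Caches the traced payload per input signature (illustrative only)."""

  def __init__(self, trace_fn: Callable[..., str]):
    self._trace_fn = trace_fn  # lowers the pallas_call wrapper to a payload string
    self._cache: Dict[Tuple, str] = {}

  def get(self, *tensors: torch.Tensor) -> str:
    key = _payload_key(tensors)
    if key not in self._cache:
      # Re-trace when a new shape/dtype combination shows up (the dynamism above).
      self._cache[key] = self._trace_fn(*tensors)
    return self._cache[key]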

alanwaketan force-pushed the alanwaketan/pallas_api branch from 8b8be2e to ae6b62b on March 12, 2024 23:38
alanwaketan (Collaborator Author)

Can I get any reviews?

JackCaoG (Collaborator) left a comment

I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use); approving to unblock.

alanwaketan (Collaborator Author)

I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use); approving to unblock.

Yea, for sure. Let me follow up with that.

alanwaketan merged commit 1bbe333 into master on Mar 13, 2024
19 checks passed
alanwaketan deleted the alanwaketan/pallas_api branch on March 13, 2024 18:39
lsy323 pushed a commit that referenced this pull request Mar 13, 2024
lsy323 added a commit that referenced this pull request Mar 13, 2024
Co-authored-by: Jiewen Tan <jwtan@google.com>