
[Pallas] Introduce jax_import_guard #6747

Merged
merged 7 commits into from
Mar 20, 2024

Conversation

alanwaketan
Collaborator

Summary:
Importing JAX will lock the TPU devices and prevent any of PyTorch/XLA's TPU computations. To address this, we need to acquire the TPU first. Somehow, calling xm.xla_device() is enough to acquire the TPU device.

Test Plan:
python test/test_pallas.py

@alanwaketan
Collaborator Author

@will-cromar Do you have better ideas?


def jax_import_guard():
  # Somehow, this could grab the TPU before JAX locks it. Otherwise, any
  # PyTorch/XLA TPU operations will hang.
  xm.xla_device()
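To illustrate the ordering contract behind the guard, here is a minimal, TPU-free sketch; `acquire_device` and the recorded event strings are stand-ins for `xm.xla_device()` and the real imports, so this runs anywhere:

```python
# Sketch of the ordering jax_import_guard enforces: PyTorch/XLA must
# touch the device before the first `import jax`, otherwise JAX locks
# the TPU and PyTorch/XLA computations hang.
events = []

def jax_import_guard_sketch(acquire_device):
    # Grab the accelerator for PyTorch/XLA before JAX gets a chance to.
    acquire_device()

jax_import_guard_sketch(lambda: events.append("torch_xla owns device"))
events.append("import jax")  # safe only after the guard has run
```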
Collaborator

torch_xla._XLAC._init_computation_client()

Collaborator Author

When will this be called in general?

Collaborator

We don't usually call this API directly, I believe; runtime init usually happens when we try to get the device. However, if all you need is to init the runtime, this API is cleaner. I used it in https://github.com/pytorch/pytorch/blob/a04e7fca8eddde492b239da6ac23d6a056666a0e/benchmarks/dynamo/common.py#L91 as well.

Collaborator

Would this API work correctly in a multi-pod environment? @JackCaoG

Collaborator

Hmm, it should. We just need to make sure that PyTorch/XLA inits the runtime first and grabs the libtpu; JAX initing the runtime after PyTorch doesn't cause any issue on a single pod, at least for us. We don't use JAX to execute any device program, so it is OK.

In multi-pod, it sounds like initing the runtime twice will cause issues? I never looked into that too much.

Collaborator

@jonb377 any thoughts on why this issue may happen on multi-pod?

@JackCaoG
Collaborator

This is funny: @will-cromar and I were discussing how to handle the TPU ownership conflict between PyTorch/XLA and JAX just this morning.

Collaborator

@will-cromar left a comment

Will this break multiprocess? If you import this, then the current process will init the runtime, and you would not be able to call xmp.spawn after that point.

@alanwaketan
Collaborator Author

alanwaketan commented Mar 14, 2024

Will this break multiprocess? If you import this, then the current process will init the runtime, and you would not be able to call xmp.spawn after that point.

What do you mean? xmp.spawn is supposed to be called before any torch-xla code, right? I hope they don't import pallas in the launcher...

@will-cromar
Collaborator

What do you mean? xmp.spawn is supposed to be called before any torch-xla code, right? I hope they don't import pallas in the launcher...

This would be a problem if custom_kernel is imported at the global scope at all, e.g.

# Inits TPU
from torch_xla.experimental import custom_kernel

def main():
    # uses `custom_kernel`
    ...

if __name__ == "__main__":
    # Would fail because the TPU is already initialized
    xmp.spawn(main)

IMO it would be safer to import jax inside make_kernel_from_pallas where it is used.
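A minimal sketch of that suggestion (hypothetical names; `colorsys`, a rarely preloaded stdlib module, stands in for `jax` so the sketch runs without JAX or a TPU):

```python
def make_kernel_from_pallas_sketch(kernel_fn):
    # Deferring the heavy import to call time means that merely importing
    # this module has no side effects: no TPU is grabbed, so xmp.spawn can
    # still fork cleanly afterwards. In custom_kernel.py the deferred
    # module would be `jax`; colorsys is only a stand-in here.
    import colorsys  # noqa: F401  (real code: `import jax`)
    return kernel_fn.__name__
```

With the import moved inside the function, `from torch_xla.experimental import custom_kernel` at module scope no longer initializes the runtime, which avoids the failing `xmp.spawn` example above.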

@alanwaketan
Collaborator Author

@will-cromar That works for me as well. Let me update it.

@JackCaoG
Collaborator

Hmm, the test failed with

Traceback (most recent call last):
  File "/tmp/pytorch/xla/test/test_pallas.py", line 10, in <module>
    from torch_xla.experimental.custom_kernel import jax_import_guard
  File "/opt/conda/lib/python3.8/site-packages/torch_xla-2.3.0+git22e6548-py3.8-linux-x86_64.egg/torch_xla/experimental/custom_kernel.py", line 18, in <module>
    import jax
ModuleNotFoundError: No module named 'jax'

@alanwaketan
Collaborator Author

hmm test failed with

Traceback (most recent call last):
  File "/tmp/pytorch/xla/test/test_pallas.py", line 10, in <module>
    from torch_xla.experimental.custom_kernel import jax_import_guard
  File "/opt/conda/lib/python3.8/site-packages/torch_xla-2.3.0+git22e6548-py3.8-linux-x86_64.egg/torch_xla/experimental/custom_kernel.py", line 18, in <module>
    import jax
ModuleNotFoundError: No module named 'jax'

Yeah, the CPU and GPU CI are not configured with JAX installed. Will fix it now.
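One common way to keep an import like this from crashing CI machines without JAX is an optional-import guard; this is a sketch of that pattern, not necessarily what the PR ended up doing, and `no_such_module_xyz` is a deliberately fake name:

```python
import importlib

def optional_import(name):
    """Return the named module, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None

# On CPU/GPU CI without JAX this yields None instead of raising
# ModuleNotFoundError at import time, so callers (e.g. test_pallas.py)
# can skip the Pallas tests rather than fail to import.
jax = optional_import("jax")
```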

@alanwaketan
Collaborator Author

@JackCaoG @will-cromar I think this is ready for the new round of reviews.

@alanwaketan alanwaketan merged commit 7cf9f10 into master Mar 20, 2024
18 checks passed