[Pallas] Introduce jax_import_guard #6747
Conversation
@will-cromar Do you have better ideas?
def jax_import_guard():
    # Somehow, this could grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operations will hang.
    xm.xla_device()
torch_xla._XLAC._init_computation_client()
When will this be called in general?
We don't usually call this API directly, I believe; runtime init usually happens when we try to get the device. However, if all you need is to init the runtime, this API is cleaner. I used it in https://github.com/pytorch/pytorch/blob/a04e7fca8eddde492b239da6ac23d6a056666a0e/benchmarks/dynamo/common.py#L91 as well.
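For reference, a minimal sketch of the guard rewritten around this suggestion. Note that _init_computation_client is a private torch_xla API, so this mirrors the diff above rather than a stable interface:

import torch_xla

def jax_import_guard():
    # Initializing the PyTorch/XLA runtime eagerly acquires the TPU
    # before a later `import jax` can lock it.
    torch_xla._XLAC._init_computation_client()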
Would this API work correctly in a multipod environment? @JackCaoG
Hmm, it should. We just need to make sure that PyTorch/XLA inits the runtime first and grabs the libtpu; JAX initing the runtime after PyTorch doesn't cause any issue on a single pod, at least for us. We don't use JAX to execute any device program, so it is OK.
On multipod, it sounds like initing the runtime twice will cause issues? I never looked into that too much.
@jonb377 Any thoughts on why this issue may happen on multipod?
This is funny; @will-cromar and I were discussing how to handle the TPU ownership conflict between PyTorch/XLA and JAX just this morning.
Will this break multiprocessing? If you import this, then the current process will init the runtime, and you would not be able to call xmp.spawn after that point.
What do you mean? xmp.spawn is supposed to be called before any torch_xla code, right? I hope they don't import Pallas in the launcher...
This would be a problem if …
IMO it would be safer to import …
@will-cromar That works for me as well. Let me update it.
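For illustration, a sketch of the multiprocessing failure mode discussed above, assuming the usual torch_xla multiprocessing entry point and that the runtime refuses to re-initialize in spawned children once the parent has grabbed it:

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    import torch_xla.core.xla_model as xm
    print(index, xm.xla_device())

if __name__ == '__main__':
    # If a top-level import had already run jax_import_guard() in this
    # parent process, the runtime would be initialized before this call,
    # and xmp.spawn would fail.
    xmp.spawn(_mp_fn)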
Hmm, test failed with …
Yeah, the CPU and GPU CI are not configured with JAX installed. Will fix it now.
Force-pushed 788d6fe to bcfd99c
@JackCaoG @will-cromar I think this is ready for a new round of reviews.
Force-pushed fb20428 to 0115d0f
Summary:
Importing JAX will lock the TPU devices and prevent any PyTorch/XLA TPU computations. To address it, we need to acquire the TPU first. Somehow, xm.xla_device() is enough to acquire the TPU device.

Test Plan:
python test/test_pallas.py
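For context, the usage pattern this enables looks roughly like the following sketch; the import path of jax_import_guard is an assumption here, shown only for illustration:

# Hypothetical call site inside a Pallas-backed module.
from torch_xla.experimental.custom_kernel import jax_import_guard  # assumed path

# Acquire the TPU for PyTorch/XLA before JAX is imported anywhere in
# the process; otherwise JAX would lock the device first.
jax_import_guard()

import jax
from jax.experimental import pallas as pl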