-
Notifications
You must be signed in to change notification settings - Fork 505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce virtual device #4091
Introduce virtual device #4091
Conversation
a7274ef
to
9976bb2
Compare
9976bb2
to
2b5dd6e
Compare
Let's make sure that we cover the explict sharded cases, where we want to avoid the initial unpartitioned data transfer. We will have to double-check, but |
Notes after chat with Yeounoh:
|
This is the only entry for the transfer data to device. |
6a326ec
to
32ff407
Compare
Remaining implementation details before I can start testing:
|
f64c1a5
to
01fa14d
Compare
df56ba8
to
ede9395
Compare
17cf0a0
to
5b7ac6f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you @steventk-g 👍
5b7ac6f
to
1a43d3a
Compare
Changes in this PR
XLA_USE_SPMD
_xla_mark_sharding
. The re-downloading path is preserved so thatXLA_USE_SPMD=0
still works as well.xla:0
fromxm.xla_device()
. At this point, users should expect all tensors to be treated as if they are on the virtual device when SPMD is enabled.