Support multihost SPMD execution #4573
Conversation
Force-pushed from b86aaec to 9e5fb70.
@@ -88,7 +89,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
 Examples
 —------------------------------
 mesh_shape = (4, 2)
-num_devices = len(xm.get_xla_supported_devices())
+num_devices = pjrt.global_device_count()
Great :)
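For context, a fuller version of the docstring example this hunk touches could read as follows. This is a rough sketch rather than the PR's exact text: the module paths (torch_xla.experimental.xla_sharding, torch_xla.experimental.pjrt) reflect the experimental SPMD API of this period, and the tensor shape and partition spec are illustrative.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental import pjrt
from torch_xla.experimental.xla_sharding import Mesh

# Build a logical mesh over all global devices (across hosts), using the
# PJRT global device count from the updated docstring line above.
mesh_shape = (4, 2)
num_devices = pjrt.global_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

# Annotate a tensor with a sharding over the two mesh axes.
t = torch.randn(8, 32).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))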
const std::vector<std::string>& devices) {
  std::unordered_map<int, int> device_index;
  for (int i = 0; i < devices.size(); ++i) {
    int global_ordinal = ParseDeviceString(devices[i]).ordinal();
The first global device gets the local index 0, so the order of the input devices list is important. Is this a correct understanding? Can we add some comments on this?
The first device in the list gets local index 0, but the order of the global ordinals within devices
doesn't matter. I'll add some more documentation around this.
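To illustrate the point, here is a minimal Python sketch of the mapping being built in the C++ snippet above (the parsing of the global ordinal is paraphrased; ParseDeviceString does the real work in the C++ code):

# Hypothetical sketch: map each device's global ordinal to its position
# (local index) in the input list. The first device in the list gets local
# index 0; the values of the global ordinals themselves do not matter.
def build_device_index(devices):
    device_index = {}
    for i, device in enumerate(devices):
        global_ordinal = int(device.split(':')[-1])  # e.g. "TPU:5" -> 5
        device_index[global_ordinal] = i
    return device_index

# On the second host of a 2-host, 8-device setup, the local devices carry
# global ordinals 4..7 but still map to local indices 0..3.
assert build_device_index(['TPU:4', 'TPU:5', 'TPU:6', 'TPU:7']) == {4: 0, 5: 1, 6: 2, 7: 3}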
LGTM, a minor comment.
Force-pushed from 9e5fb70 to fe74d84.
Force-pushed from ba5ed74 to e999a95.
@@ -931,30 +931,24 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
 std::vector<xla::ComputationClient::DataPtr> new_handles;  // out
 if (shardings[i] != nullptr) {
   xla::OpSharding sharding = shardings[i]->sharding;
   // TODO(yeounoh) PJRT runs a process per host for SPMD and without cross
   // host communications. This means that we may need to manually shard
   // across global devices for multi-host training.
   std::vector<std::string> local_devices =
Does GetLocalDevices() return local devices with global ordinals? If so, let's leave a comment.
LGTM, I have 2 nit comments.
Force-pushed from e999a95 to d1fc7b1.
Force-pushed from d1fc7b1 to bc538a8.
The main change needed to support multihost execution is to restrict the shards generated in ShardTensor to those that belong to addressable devices.
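Conceptually, the restriction amounts to filtering the globally computed shards down to those whose target device is addressable from the current process. The Python sketch below is only an illustration of that idea (the actual change lives in the C++ ShardTensor path; the helper name and arguments here are hypothetical):

# Hypothetical sketch: keep only the shards destined for devices this host
# can address. global_shards[i] is assumed to correspond to global_devices[i].
def filter_addressable_shards(global_shards, global_devices, local_devices):
    addressable = set(local_devices)
    return [(shard, device)
            for shard, device in zip(global_shards, global_devices)
            if device in addressable]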