-
Notifications
You must be signed in to change notification settings - Fork 505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add v4 config to xla_dist #3440
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @JackCaoG.
Added a house cleaning comment.
LGTM otherwise.
@@ -316,6 +316,17 @@ def _tpuvm_env_vars_cmd(self, worker_idx): | |||
'v3-512': '8,8,1', | |||
'v3-1024': '8,16,1', | |||
'v3-2048': '16,16,1', | |||
# v4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd expect these configs to continue growing.
Instead of hard coding these confits, I suggest we moving them to a xla_dist_configs.yaml
file.
Wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: suggesting to include a comment defining dimensions:
number of host machines
, number of chips per host machine
, number of cores per chip
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I expect this file to be heavily modified after pjrt migration. xla_dist
is xrt focus. I think we can clean up the code later.
@@ -327,6 +338,13 @@ def _tpuvm_env_vars_cmd(self, worker_idx): | |||
|
|||
def _env_vars_cmd(self, worker_idx): | |||
client_worker = self._cluster.get_client_workers()[worker_idx] | |||
accelerator_gen = self._cluster.get_service_workers( | |||
)[0]._machine_type.split('-')[0] | |||
accelerator_gen_to_tpu_num_devices = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it is worth to put every config into yaml, especially if it only 3 lines long. TPU_NUM_DEVICES
will most likely go away after pjrt migration too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. Let's address these issues when we handle PJRT at a later time.
Thanks @JackCaoG
Currently xla_dist does not support v4 since mesh config is missing.