Add v4 config to xla_dist #3440

Merged
merged 1 commit into from
Apr 8, 2022

JackCaoG
Collaborator

Currently xla_dist does not support v4 because the mesh config for it is missing.

@JackCaoG JackCaoG requested a review from miladm March 23, 2022 04:15
Collaborator

@miladm miladm left a comment

Thanks @JackCaoG.

Added a house cleaning comment.
LGTM otherwise.

@@ -316,6 +316,17 @@ def _tpuvm_env_vars_cmd(self, worker_idx):
'v3-512': '8,8,1',
'v3-1024': '8,16,1',
'v3-2048': '16,16,1',
# v4
Collaborator

@miladm miladm Mar 28, 2022

I'd expect these configs to continue growing.
Instead of hard-coding these configs, I suggest we move them to an xla_dist_configs.yaml file.
Wdyt?

Collaborator

nit: suggesting to include a comment defining dimensions:
number of host machines, number of chips per host machine, number of cores per chip

Collaborator Author

I expect this file to be heavily modified after the PJRT migration. xla_dist is XRT-focused, so I think we can clean up the code later.
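
As a rough illustration of the suggestion in this thread, the mesh table could live outside the code and be loaded at startup. The sketch below is not what the PR does: the file layout, `parse_flat_config`, and `mesh_dims` are hypothetical names, only the three v3 entries visible in the diff are reproduced, and a flat `key: value` parser stands in for a real YAML library to keep the example dependency-free.

```python
# Hypothetical sketch. The file name xla_dist_configs.yaml comes from the
# review comment; the flat layout and helper names are assumptions.

# Stand-in for the contents of xla_dist_configs.yaml. Only the three v3
# entries visible in the diff are reproduced here.
CONFIG_TEXT = """
# accelerator type: mesh config, three comma-separated dimensions
v3-512: '8,8,1'
v3-1024: '8,16,1'
v3-2048: '16,16,1'
"""


def parse_flat_config(text):
  """Parses flat `key: 'value'` lines (a real loader would use a YAML library)."""
  table = {}
  for raw_line in text.splitlines():
    line = raw_line.strip()
    # Skip blank lines and comments.
    if not line or line.startswith('#'):
      continue
    key, _, value = line.partition(':')
    table[key.strip()] = value.strip().strip("'\"")
  return table


def mesh_dims(table, accelerator_type):
  """Returns the mesh config as a tuple of ints, or raises a clear error."""
  if accelerator_type not in table:
    raise ValueError('No mesh config for accelerator type: ' + accelerator_type)
  return tuple(int(dim) for dim in table[accelerator_type].split(','))


MESH_TABLE = parse_flat_config(CONFIG_TEXT)
```

With this layout, supporting a new accelerator type would be a one-line change to the config file, and an unknown type fails with an explicit error instead of a bare `KeyError`.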

@@ -327,6 +338,13 @@ def _tpuvm_env_vars_cmd(self, worker_idx):

  def _env_vars_cmd(self, worker_idx):
    client_worker = self._cluster.get_client_workers()[worker_idx]
    accelerator_gen = self._cluster.get_service_workers(
    )[0]._machine_type.split('-')[0]
    accelerator_gen_to_tpu_num_devices = {
Collaborator

ditto

Collaborator Author

I don't think it is worth putting every config into YAML, especially when it is only 3 lines long. TPU_NUM_DEVICES will most likely go away after the PJRT migration too.
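
For context, the lookup under discussion takes the generation prefix of the machine type and maps it to a device count. A minimal standalone sketch follows; the diff does not show the mapping's contents, so the values below are placeholders, and the function names are hypothetical rather than part of xla_dist.

```python
# Placeholder values: the diff hunk is cut off before the mapping's entries,
# so these numbers are assumptions for illustration, not taken from the PR.
ACCELERATOR_GEN_TO_TPU_NUM_DEVICES = {'v2': 8, 'v3': 8, 'v4': 4}


def accelerator_gen(machine_type):
  """Extracts the generation prefix, e.g. 'v4' from 'v4-32'."""
  return machine_type.split('-')[0]


def tpu_num_devices(machine_type):
  """Looks up the device count for a machine type by its generation."""
  gen = accelerator_gen(machine_type)
  if gen not in ACCELERATOR_GEN_TO_TPU_NUM_DEVICES:
    raise ValueError('Unsupported accelerator generation: ' + gen)
  return ACCELERATOR_GEN_TO_TPU_NUM_DEVICES[gen]
```

A mapping this small is easy to read inline, which is the argument made above for leaving it in the code.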

@miladm miladm self-requested a review April 8, 2022 17:26
Collaborator

@miladm miladm left a comment

Sounds good. Let's address these issues when we handle PJRT at a later time.

Thanks @JackCaoG

@JackCaoG JackCaoG merged commit cc04953 into master Apr 8, 2022
@JackCaoG JackCaoG deleted the v4_pod_config branch April 8, 2022 17:28