
Make layout pinning optional for cross-core communication #3511

Merged · 2 commits merged into master from layout_pin_optional on Apr 19, 2022

Conversation

JackCaoG
Collaborator

This is to fix #3506. I verified that the test passes with either all pin_layout=True (the default) or all pin_layout=False.

PyTorch/XLA currently compiles the graph separately for each TPU core (or GPU core). The graph generated on each core can be slightly different due to:

  1. slightly different input shapes
  2. different embedding table sizes
  3. ...

The XLA compiler can tolerate small differences among cores, but this becomes a problem for communication ops. If the input shape difference ends up producing a layout difference among the tensors that the user wants to pass to a communication op, data corruption can occur.

To overcome this problem we introduced layout pinning, which guarantees that all cores participating in the communication have the same layout for the input tensor. However, in some corner cases pinning every layout does not work. For example, all_gather(pin) + reduce_scatter(pin) might fail in some cases.

This PR aims to provide a workaround when such a failure happens. PyTorch/XLA will pin all communication op layouts by default, but if there is a compilation error with the message "HloModule has a mix of layout constrained", the user can choose to unpin all layouts.
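
A minimal sketch of the workaround, assuming the pin_layout keyword this PR adds to the collective ops in torch_xla.core.xla_model (argument names and ordering here are illustrative and may differ between versions):

```python
import torch_xla.core.xla_model as xm

def collective_step(t):
    # Workaround when compilation fails with
    # "HloModule has a mix of layout constrained instructions":
    # opt out of layout pinning on the collectives involved.
    gathered = xm.all_gather(t, dim=0, pin_layout=False)
    reduced = xm.reduce_scatter(
        xm.REDUCE_SUM,                   # reduction type
        gathered,                        # input tensor
        scale=1.0,
        scatter_dim=0,
        shard_count=xm.xrt_world_size(),
        pin_layout=False,
    )
    return reduced
```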

FYI @ronghanghu @hjm-aws

@JackCaoG JackCaoG requested a review from yeounoh April 19, 2022 03:15
@JackCaoG
Collaborator Author

I will merge this PR once all tests pass to unblock the user, but I want someone to review it too. I can put up a follow-up PR to address the review comments. @yeounoh

@ronghanghu
Collaborator

ronghanghu commented Apr 19, 2022

Thanks! Just to double-check, before this PR, we currently have the following behavior:

  • all_reduce: pinned
  • all_to_all: pinned
  • all_gather: unpinned
  • reduce_scatter: unpinned

Is this right? (Asking as I'm trying to understand what I need to do if I want to re-run the tests under the current behavior.)

@JackCaoG
Collaborator Author

@ronghanghu yea, your statement is correct.
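
For reference, a hedged sketch of how the pre-PR behavior listed above could be spelled out explicitly once this PR lands (run inside an xmp.spawn worker as usual; exact signatures may vary slightly between torch_xla versions):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.randn(8, 128, device=device)
world_size = xm.xrt_world_size()

# Pre-PR behavior, made explicit via pin_layout:
reduced = xm.all_reduce(xm.REDUCE_SUM, t, pin_layout=True)          # was pinned
shuffled = xm.all_to_all(t, split_dimension=0, concat_dimension=0,
                         split_count=world_size, pin_layout=True)   # was pinned
gathered = xm.all_gather(t, dim=0, pin_layout=False)                # was unpinned
scattered = xm.reduce_scatter(xm.REDUCE_SUM, gathered, scale=1.0,
                              scatter_dim=0, shard_count=world_size,
                              pin_layout=False)                     # was unpinned
```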

@JackCaoG
Collaborator Author

I modified the default parameter so that only the all_reduce layout is pinned by default, to make test_mp_distributed_mm.py work. This is also what Blake suggested.
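
Under that revised default, only all_reduce would pin its layout automatically; the other collectives would need pin_layout=True to opt in (illustrative sketch, per the comment above):

```python
import torch
import torch_xla.core.xla_model as xm

t = torch.randn(4, 4, device=xm.xla_device())

reduced = xm.all_reduce(xm.REDUCE_SUM, t)            # all_reduce: pinned by default
gathered = xm.all_gather(t, dim=0, pin_layout=True)  # other collectives: opt in explicitly
```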

@JackCaoG JackCaoG force-pushed the layout_pin_optional branch from fcdbd24 to e4e01e4 Compare April 19, 2022 07:45
@JackCaoG JackCaoG merged commit 5ece4ca into master Apr 19, 2022
@JackCaoG JackCaoG deleted the layout_pin_optional branch April 19, 2022 07:53
ronghanghu added a commit to ronghanghu/xla that referenced this pull request Apr 20, 2022
Successfully merging this pull request may close these issues.

All-reduce together w/ reduce-scatter causes crash on nightly 20220413 build