Make layout pinning optional for cross core communication #3511
This is to fix #3506. I verified that with all `pin_layout=True` (the default) or all `pin_layout=False`, the test passes.

PyTorch/XLA currently compiles the graph separately for each TPU core (or GPU core). It is possible that the graph generated on each core is slightly different, for example due to differences in the input shapes each core handles.
The XLA compiler can tolerate small differences among cores, but this is a problem for communication ops: if the input shape differences end up producing a layout difference among the tensors the user wants to pass to a communication op, there can be data corruption.
To overcome this problem we introduce layout pinning, which guarantees that all cores participating in the communication have the same layout for the input tensor. However, in some corner cases pinning all layouts will not work; for example `all_gather(pin)` + `reduce_scatter(pin)` might fail in some cases.

This PR aims to provide a workaround when such a failure happens. PyTorch/XLA will pin all communication op layouts by default, but if there is a compilation error with the message `HloModule has a mix of layout constrained`, the user can choose to unpin all layouts (see the sketch below).
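A minimal usage sketch, assuming the `pin_layout` keyword is exposed on the `torch_xla.core.xla_model` collective helpers (`all_gather`, `reduce_scatter`) with their existing argument lists; the key point is that all collectives in the same graph should use the same `pin_layout` value rather than a mix:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    t = torch.arange(8, dtype=torch.float32, device=device)

    # Default behavior: the layout of the communication input is pinned,
    # so every participating core sees the same layout.
    gathered = xm.all_gather(t, dim=0, pin_layout=True)

    # Workaround: if compilation fails with
    # "HloModule has a mix of layout constrained ...",
    # unpin the layout on *all* communication ops in the graph.
    gathered = xm.all_gather(t, dim=0, pin_layout=False)
    reduced = xm.reduce_scatter(xm.REDUCE_SUM, gathered, scale=1.0,
                                scatter_dim=0,
                                shard_count=xm.xrt_world_size(),
                                pin_layout=False)
    xm.mark_step()


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
```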
FYI @ronghanghu @hjm-aws