[SPMD] Support manual all-reduce #7576
Conversation
Approving to unblock, but I think we should fix the tensor method name.
@@ -392,6 +392,13 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
  }
}

XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type,
Can you call it `all_reduce_no_token`? The only difference in the signature is that it does not take `pin_layout`, but the main difference in the op is that it does not set the token. It would be better to reflect that in the name.
Sure. I can follow up with that.
For array support, do you plan to call
I don't think that's necessary. I'm thinking the compiler should be smart enough to fuse all-reduces if fusion is necessary.
Thanks Jack for approving.
Summary:
This adds manual all-reduce support to SPMD; it currently supports only one input tensor. For array support, we can handle that in the Python layer instead.
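The summary proposes keeping the C++ op single-tensor and handling lists of tensors in the Python layer. A minimal sketch of that idea, using plain Python lists of per-participant values in place of XLA tensors — `all_reduce_one` and `all_reduce_many` are hypothetical names for illustration, not the PR's actual API:

```python
# Sketch of "array support in the Python layer": the single-tensor op is
# applied once per tensor in a list. `all_reduce_one` stands in for the
# single-tensor all-reduce this PR adds; the real op runs on XLA tensors,
# so plain lists are used here purely to show the semantics (every
# participant ends up with the elementwise sum across participants).

def all_reduce_one(shards):
    """Single-tensor all-reduce (sum): each inner list is one
    participant's copy of the tensor; all participants receive the
    elementwise sum."""
    total = [sum(vals) for vals in zip(*shards)]
    return [list(total) for _ in shards]

def all_reduce_many(shard_lists):
    """Hypothetical Python-layer wrapper: loop over a list of tensors,
    invoking the single-tensor op once per tensor."""
    return [all_reduce_one(shards) for shards in shard_lists]

if __name__ == "__main__":
    # Two participants, each holding two tensors of two elements.
    t0 = [[1, 2], [3, 4]]      # tensor 0: values held by each participant
    t1 = [[10, 20], [30, 40]]  # tensor 1: values held by each participant
    print(all_reduce_many([t0, t1]))
    # [[[4, 6], [4, 6]], [[40, 60], [40, 60]]]
```

Whether a loop like this costs anything at runtime depends on the compiler; as noted in the discussion above, XLA can fuse adjacent all-reduces when fusion is beneficial, so a Python-level loop need not imply separate collectives.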
Test Plan:
python ./test/spmd/test_xla_sharding.py -v -k test_spmd_all_reduce