protect tensor parallel usage #34800
Conversation
🔥
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -5005,6 +5006,8 @@ def tensor_parallel(self, device_mesh):
         device_mesh (`torch.distributed.DeviceMesh`):
             The device mesh to use for tensor parallelism.
         """
+        if not is_torch_greater_or_equal_than_2_4:
```
Hi @ArthurZucker - this mismatch between the torch 2.4 check and the actual 2.5 requirement means that torch 2.4 still hits this issue (while torch 2.3 now works properly).
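A minimal sketch of what the corrected guard could look like if the requirement is bumped to torch >= 2.5; the `is_torch_greater_or_equal_than_2_5` flag below is defined inline for illustration, mirroring the existing 2.4 helper rather than quoting the PR's actual code:

```python
import torch
from packaging import version

# Assumed helper, same shape as transformers' is_torch_greater_or_equal_than_2_4 flag.
is_torch_greater_or_equal_than_2_5 = version.parse(torch.__version__) >= version.parse("2.5")

# Guard the tensor parallel entry point behind the version it actually needs.
if not is_torch_greater_or_equal_than_2_5:
    raise EnvironmentError(
        f"tensor_parallel requires torch >= 2.5, but found torch {torch.__version__}."
    )
```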
Thanks for the fix!
```python
if is_torch_greater_or_equal_than_2_4:
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
    )
```
Actually, `Replicate` is in `torch.distributed.tensor` only for torch >= 2.5. Between 2.0 and 2.4 (included), it is in `torch.distributed._tensor`. Thus it seems there are two options:

Option 1:

```python
try:
    from torch.distributed.tensor import Replicate
except ImportError:
    from torch.distributed._tensor import Replicate
```

Option 2: bump the requirement to 2.5.

`ColwiseParallel` and `RowwiseParallel` have existed since 2.0.
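A self-contained sketch of what Option 1 could look like in context, combining the fallback import with the parallel-plan imports that have been available since 2.0 (a sketch under the version facts stated above, not the PR's final code):

```python
# Import Replicate from its public torch >= 2.5 location, falling back to
# the private module used in torch 2.0-2.4.
try:
    from torch.distributed.tensor import Replicate  # torch >= 2.5
except ImportError:
    from torch.distributed._tensor import Replicate  # torch 2.0-2.4

# These have lived in torch.distributed.tensor.parallel since torch 2.0,
# so no fallback is needed for them.
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
)
```

The try/except keeps torch 2.4 working without bumping the minimum requirement, at the cost of relying on a private module on older versions.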
Reverts #6759

Requires from transformers:
- huggingface/transformers#34816
- huggingface/transformers#34800

Todo:
- [x] Need to merge first PR to get support for torch 2.4
What does this PR do?
Fixes #34795