Cast input before moving to device for all strategies #18264
Conversation
What does this PR do?
With this change, all strategies cast the input batch (if needed) before moving it to the device; previously only DeepSpeed did this.
This is a follow-up to the discussion #18217 (comment)
The change only affects the double- and half-precision plugins (16-true, bf16-true, 64-true).
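For illustration, below is a minimal sketch of the cast-then-move ordering described above. The helper name `_cast_then_move` and the recursion over common collection types are assumptions made for the example; they are not Lightning's actual strategy or precision-plugin code.

```python
import torch
from typing import Any


def _cast_then_move(batch: Any, dtype: torch.dtype, device: torch.device) -> Any:
    # Hypothetical helper showing the ordering: cast floating-point tensors to
    # the precision plugin's dtype first, then transfer them to the device.
    if isinstance(batch, torch.Tensor):
        if batch.is_floating_point():
            batch = batch.to(dtype)  # cast while still on the source device
        return batch.to(device)      # move to the target device afterwards
    if isinstance(batch, (list, tuple)):
        return type(batch)(_cast_then_move(x, dtype, device) for x in batch)
    if isinstance(batch, dict):
        return {k: _cast_then_move(v, dtype, device) for k, v in batch.items()}
    return batch


if __name__ == "__main__":
    batch = {"x": torch.randn(4, 8), "y": torch.zeros(4, dtype=torch.long)}
    # A 16-true plugin would target torch.half; bf16-true maps to
    # torch.bfloat16 and 64-true to torch.double. Integer tensors are untouched.
    out = _cast_then_move(batch, torch.half, torch.device("cpu"))
    print(out["x"].dtype, out["y"].dtype)  # torch.float16 torch.int64
```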
cc @Borda @carmocca @justusschock @awaelchli