[PyTorch] Minor optimizations to reduce CPU overheads in modules #1191
base: main
Conversation
Avoid enable_grad context when possible in cast function. Cache distributed group properties. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Avoid torch.nn.Module impl of __setattr__. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Can you propagate the CPU offloading import fix to GroupedLinear as well?

```python
from ..cpu_offload import CPUOffloadEnabled
```
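For context, a standalone toy benchmark (with `os.path` standing in for `te.pytorch.cpu_offload`) of why a per-call import shows up in profiles: even after the module is cached, each call still pays for a `sys.modules` lookup and a name rebinding.

```python
import timeit

def forward_with_import():
    # The module is already in sys.modules after the first call, but
    # Python still performs the lookup and rebinds the name each time.
    from os import path
    return path

from os import path as _cached_path

def forward_with_cached_import():
    # Import done once at module scope; this is just a global load.
    return _cached_path

print(timeit.timeit(forward_with_import, number=100_000))
print(timeit.timeit(forward_with_cached_import, number=100_000))
```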
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
```python
# Before:
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
# After (the line this comment is anchored on):
self._fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
```
I wonder if we couldn't instead just do something like

```python
te_params = self.get_te_params()  # calls _fast_getattr internally, te_params is a normal object
te_params.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
```
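A runnable sketch of this suggestion, with hypothetical scaffolding (`_TEParams` and the body of `get_te_params` are illustrative, not TE's actual API): the TE-specific flags live on a plain object, so reads and writes never hit `torch.nn.Module`'s attribute machinery.

```python
import torch

class _TEParams:
    """Plain container; attribute writes are ordinary __dict__ stores."""
    def __init__(self):
        self.fp8_parameters = False
        self.fp8 = False
        self.fp8_calibration = False

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # object.__setattr__ skips Module's parameter/buffer/submodule
        # bookkeeping when storing the container itself.
        object.__setattr__(self, "_te_params", _TEParams())

    def get_te_params(self) -> _TEParams:
        return self._te_params

m = MyModule()
te_params = m.get_te_params()
te_params.fp8_parameters = True  # never touches Module.__setattr__
```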
Otherwise it will be hard to enforce that everybody uses only the `_fast_get/setattr` methods, I think.
Even better, we could store these attrs in `fp8_meta` or some other dict. I feel like the behavior of `torch.nn.Module` is a hint that we shouldn't change its attrs frequently.
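A minimal sketch of this dict-based variant, assuming `fp8_meta` is an ordinary per-module dict (the module below is a stand-in, not TE code):

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fp8_meta = {}  # one Module.__setattr__ call, at init time only

    def forward(self, inp):
        # Per-iteration flags are plain dict writes; Module.__setattr__
        # is never involved after __init__.
        self.fp8_meta["fp8"] = True
        self.fp8_meta["fp8_calibration"] = False
        return inp

m = MyModule()
m(torch.zeros(1))
```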
```diff
@@ -369,7 +371,7 @@ def forward(
         out, _ = allreduce(out, tp_group)

     # [*, in_features] -> [*, out_features] except first dimension changes for SP
-    out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
+    out = out.view(-1, *inp_shape[1:-1], out_features)
```
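For illustration, a runnable toy of the cached-shape pattern in this hunk, with a plain matmul standing in for the GEMM call: tensor metadata is read once into Python values (`inp_shape`, `out_features`) instead of being re-queried from tensor attributes at the end of the forward pass.

```python
import torch

inp = torch.randn(8, 16, 32)   # [*, in_features]
weight = torch.randn(64, 32)   # [out_features, in_features]

# Cache tensor metadata once as plain Python values.
inp_shape = inp.shape
out_features = weight.shape[0]

out = inp.reshape(-1, inp_shape[-1]) @ weight.t()
# [*, in_features] -> [*, out_features]
out = out.view(-1, *inp_shape[1:-1], out_features)
print(out.shape)  # torch.Size([8, 16, 64])
```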
Why can't we just create out with the right shape in the gemm call instead?
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Description
We have observed that TE modules experience non-trivial CPU overhead, which often becomes a performance bottleneck in the forward pass of small models. For example, measuring the CPU runtime for Megatron-core modules with BF16 compute and TP=1:

- `ColumnParallelLinear`: 74 us per forward pass
- `TEColumnParallelLinear`: 140 us per forward pass

Unfortunately this overhead is distributed throughout the forward pass. Many basic PyTorch operations, e.g. getting attributes from `torch.Tensor`, involve O(1 us) overhead, so even the basic checks needed to handle all of our advanced features eventually add up to something non-trivial.
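As a rough illustration of that O(1 us) figure, here is a standalone microbenchmark of a single tensor attribute read (a sketch; exact numbers vary by machine and build):

```python
import timeit

import torch

t = torch.randn(4, 4)

# Time a single tensor attribute read. Each access goes through
# torch.Tensor's C++ property plumbing rather than a plain dict lookup.
n = 100_000
per_access_us = timeit.timeit(lambda: t.shape, number=n) / n * 1e6
print(f"t.shape: ~{per_access_us:.2f} us per access")
```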
This PR makes a few minor optimizations:

- Avoid importing `te.pytorch.cpu_offload` in every forward pass
- Avoid `torch.nn.Module.__setattr__` when possible
- Avoid getting attributes from `torch.nn.Module` when possible

I see a 1.22x speedup, with 115 us per forward pass.
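For the `__setattr__` item, a minimal sketch of the bypass: `_fast_setattr` is the name that appears in the diff above, while the scaffolding around it here is assumed.

```python
import torch

class FastModule(torch.nn.Module):
    def _fast_setattr(self, name: str, value) -> None:
        # torch.nn.Module.__setattr__ checks the _parameters, _buffers,
        # and _modules dicts on every assignment; writing straight into
        # __dict__ skips that work. Only safe for plain Python attributes,
        # never for Parameters, buffers, or submodules.
        object.__setattr__(self, name, value)

m = FastModule()
m._fast_setattr("fp8", True)
assert m.fp8 is True
```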
Type of change
Changes

- Avoid importing `te.pytorch.cpu_offload` in every forward pass
- Avoid `torch.nn.Module.__setattr__` when possible
- Avoid getting attributes from `torch.nn.Module` when possible

Checklist: