[PyTorch] Minor optimizations to reduce CPU overheads in modules #1191

Open · timmoon10 wants to merge 11 commits into main

Conversation

@timmoon10 (Collaborator) commented Sep 18, 2024

Description

We have observed that TE modules experience non-trivial CPU overhead, which often becomes a performance bottleneck in the forward pass of small models. For example, measuring the CPU runtime for Megatron-core modules with BF16 compute and TP=1:

Unfortunately, this overhead is distributed throughout the forward pass. Many basic PyTorch operations, e.g. getting attributes from a torch.Tensor, each involve on the order of 1 us of overhead, so even the basic checks needed to handle all of our advanced features eventually add up to something non-trivial.

This PR makes a few minor optimizations:

  • Avoid importing from te.pytorch.cpu_offload in every forward pass
  • Memoize NCCL process group properties
  • Avoid custom logic in torch.nn.Module.__setattr__ when possible
  • Avoid custom logic for accessing params in torch.nn.Module when possible
  • Avoid accessing tensor attrs more than necessary

I see a 1.22x speedup, with 115 us per forward pass.
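
As a rough illustration of the first two items above (a sketch only; the function names cpu_offload_enabled and group_size_and_rank are hypothetical, not the PR's actual helpers):

import functools

import torch.distributed as dist

# One-time import at module load, instead of executing
# "from ..cpu_offload import CPUOffloadEnabled" inside every forward pass.
# The flag is read as a module attribute so later changes to it are still seen.
from transformer_engine.pytorch import cpu_offload


def cpu_offload_enabled() -> bool:
    # Attribute lookup on an already-imported module is much cheaper than
    # re-running an import statement on every call.
    return cpu_offload.CPUOffloadEnabled


@functools.lru_cache(maxsize=None)
def group_size_and_rank(group: dist.ProcessGroup) -> tuple:
    # A process group's world size and rank never change after creation,
    # so the distributed queries can be memoized per group.
    return dist.get_world_size(group), dist.get_rank(group)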

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Avoid importing from te.pytorch.cpu_offload in every forward pass
  • Memoize NCCL process group properties
  • Avoid custom logic in torch.nn.Module.__setattr__ when possible
  • Avoid custom logic for accessing params in torch.nn.Module when possible
  • Avoid accessing tensor attrs more than necessary
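
Later in the diff, a _fast_setattr helper is used for the torch.nn.Module.__setattr__ item above. A minimal sketch of the general pattern, assumed here for illustration (the PR's actual implementation may differ):

import torch


class Example(torch.nn.Module):

    def _fast_setattr(self, name: str, value) -> None:
        # torch.nn.Module.__setattr__ runs isinstance checks and bookkeeping for
        # Parameters, buffers, and sub-Modules on every assignment. For plain
        # Python values that are none of those, writing through
        # object.__setattr__ skips that logic.
        object.__setattr__(self, name, value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-iteration flags are plain Python objects, so the fast path is safe.
        self._fast_setattr("fp8", False)
        return x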

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Commits

  • Avoid enable_grad context when possible in cast function. Cache distributed group properties. (Signed-off-by: Tim Moon <tmoon@nvidia.com>)
  • Avoid torch.nn.Module impl of __setattr__. (Signed-off-by: Tim Moon <tmoon@nvidia.com>)

@timmoon10 added the enhancement (New feature or request) label on Sep 18, 2024

@yaox12 (Collaborator) left a comment

Can you propagate the CPU offloading import fix to GroupedLinear as well?

from ..cpu_offload import CPUOffloadEnabled

self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self._fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())

Member

I wonder if we couldn't instead just do something like

te_params = self.get_te_params()  # calls _fast_getattr internally, te_params is a normal object
te_params.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()

Member

Otherwise it will be hard to enforce that everybody uses only _fast_get/setattr, I think.

@timmoon10 (Collaborator, Author) commented Sep 20, 2024

Even better, we could store these attrs in fp8_meta or some other dict. I feel like the behavior of torch.nn.Module is a hint that we shouldn't change its attrs frequently.
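
A rough sketch of that dict-based idea (illustrative only; the attribute and key names are placeholders, not the PR's code):

import torch


class Example(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        # The dict is stored once as a regular attribute; mutating its entries
        # afterwards never goes through torch.nn.Module.__setattr__.
        self._iter_state = {"fp8": False, "fp8_calibration": False}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        state = self._iter_state  # single attribute lookup
        state["fp8"] = False  # plain dict writes, no Module bookkeeping
        state["fp8_calibration"] = False
        return x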

@@ -369,7 +371,7 @@ def forward(
out, _ = allreduce(out, tp_group)

# [*, in_features] -> [*, out_features] except first dimension changes for SP
- out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
+ out = out.view(-1, *inp_shape[1:-1], out_features)

Member

Why can't we just create out with the right shape in the gemm call instead?
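
The hunk quoted above is an instance of the "avoid accessing tensor attrs more than necessary" item: repeated attribute accesses such as .shape on a torch.Tensor each carry some Python-side overhead, so the values are read into locals once and reused. A generic sketch of the pattern (not the PR's code; the function and argument names are made up):

import torch


def project(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Read tensor metadata once into plain Python objects ...
    inp_shape = inp.shape
    out_features = weight.shape[0]
    # ... then reuse the cached values instead of re-querying tensor attributes.
    out = torch.matmul(inp.view(-1, inp_shape[-1]), weight.t())
    return out.view(-1, *inp_shape[1:-1], out_features)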

Labels: enhancement (New feature or request)
Projects: None yet
3 participants