
[PyTorch] Prototype for operation-based API #707

Merged (74 commits) on Jul 9, 2024

Conversation

@timmoon10 (Collaborator) commented on Mar 9, 2024:

Currently, Transformer Engine exposes fused operations with custom modules like LayerNormLinear. These are highly tuned for certain workloads (especially GPT), but are not easy to generalize to other models. This approach is especially cumbersome when the forward and backward passes have different fusion opportunities (e.g. forward GEMM+bias+gelu and backward dgelu+dbias+cast+transpose).

This PR adds a new API for specifying Transformer Engine models. Instead of using large compound modules (e.g. LayerNormLinear), users can build up a Sequential module out of small FusibleOperations (e.g. LayerNorm, Linear). The Sequential module (with an API similar to torch.nn.Sequential) will internally attempt to fuse operations together (possibly differently in the forward and backward passes).
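
For illustration, here is a minimal usage sketch of the proposed API (the constructor arguments, shapes, and dtype choices are assumptions for this example, not part of the PR):

import torch
import transformer_engine.pytorch as te

# Build a model out of small fusible operations. The fuser may combine
# Linear + Bias into a single fused op, possibly differently in the
# forward and backward passes.
model = te.ops.Sequential(
    te.ops.Linear(1024, 1024),
    te.ops.Bias(1024),
)

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
y = model(x)  # operation fusions are planned internally
y.sum().backward()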

Some of the more important components:

  • te.ops.FusibleOperation: A neural network operation that can be processed by the fuser. It has forward and backward functions similar to torch.autograd.Function.
  • te.ops.BasicOperation: A minimal FusibleOperation. Its forward and backward functions must be implemented, and it holds the model state and parameters (a rough sketch appears after this list).
  • te.ops.FusedOperation: A FusibleOperation that is interchangeable with multiple BasicOperations. If it implements a forward or backward function, it must save the same context as the corresponding BasicOperations.
  • te.ops.Sequential: A container module with an API similar to torch.nn.Sequential.
  • te.ops.OperationFuser: A helper class that manages autograd, performs the operation fusions, and keeps track of the corresponding BasicOperations and FusedOperations.
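
For context, a rough sketch of what a BasicOperation subclass could look like (the method names, signatures, and return convention below are inferred loosely from this PR and may not match the final API):

class Scale(te.ops.BasicOperation):
    """Toy op that multiplies its input by a fixed scalar"""

    def __init__(self, alpha: float) -> None:
        super().__init__()
        self.alpha = alpha

    def op_forward(self, ctx, input_, *args, **kwargs):
        # Nothing needs to be saved in ctx for the backward pass
        return input_ * self.alpha

    def op_backward(self, ctx, grad_output):
        # Gradient w.r.t. the input, plus gradients w.r.t. parameters (none here)
        return grad_output * self.alpha, ()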

As a proof of concept, I've been able to fuse Linear and Bias operations, both on a single GPU and with tensor parallelism. These modules have been implemented to support Float8Tensor, which simplifies the implementation and will be important for future work (e.g. FP8 attention). I've also added single-GPU and multi-GPU tests.

This work is heavily influenced by #377 from @janekb04.

Remaining tasks:

  • FP8 scaling factor updates
  • Checkpointing
  • Documentation

Future work:

  • Operations: layer norm, activations, attention
  • Fusions
  • Possibly reimplementing the existing modules using this infrastructure

@timmoon10 added the "enhancement" (New feature or request) label on Mar 9, 2024.

Commit notes from the development timeline:

  • Runs, but need to validate. Runtime errors with non-FP8 params and FP8 compute, or FP8 params and non-FP8 compute.
  • Test does not pass with FP8.
  • Not supported by cuBLAS.
  • Still need to implement amax reductions.
  • Add documentation for unfused ops.
  • Expand documentation.
@timmoon10 (author): /te-ci pytorch

timmoon10 and others added 2 commits on June 14, 2024 at 17:44.
@timmoon10 (author): /te-ci pytorch

@sudhakarsingh27 (Collaborator) left a review comment:

pass 1

transformer_engine/pytorch/ops/basic/basic_linear.py (outdated; resolved)
@property
@abc.abstractmethod
def is_fused_op(self) -> bool:
    """Whether this op is the fusion of one or more basic ops"""
Collaborator (reviewer) suggested change (add pass after the docstring):

    """Whether this op is the fusion of one or more basic ops"""
    pass

@timmoon10 (author): PyLint prefers just putting the docstring: 738df8a

"""Whether this op is the fusion of one or more basic ops"""

def pre_forward(self) -> None:
"""Preprocessing before forward pass"""
Collaborator (reviewer) suggested change (add pass after the docstring):

    """Preprocessing before forward pass"""
    pass

@timmoon10 (author): PyLint prefers just putting the docstring: 738df8a

curr_len = meta.amax_history.size(0)
if curr_len == amax_history_len:
    continue
with torch.no_grad():
Collaborator (reviewer): Just curious, why do we need torch.no_grad here?

@timmoon10 (author): I don't think it's needed, but I'm being paranoid about leaking the autograd graph. This code path is infrequent, but it is called outside the OperationFuser's autograd function.


Parameters
----------
mode: {"input", "param", "grad_output"}
Collaborator (reviewer): Is "name" a better fit for this arg?


for fp8_meta in self._fp8_metas.values():
    self._check_fp8_meta(fp8_meta)

# Register FP8 metadata for amax and scale update
Collaborator (reviewer): Is this part of the code (or in spirit) from prepare_for_forward in the original API?

@timmoon10 (author): Exactly:

if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
        self.fp8_meta, fp8_weights=self._get_fp8_params()
    )

Although now that you mention it, we should register "grad_output" in the backward pass instead of the forward.

@timmoon10 (author) on Jun 27, 2024: Actually, this matches the module API. The fp8_metas are registered in the forward pass, and we manually trigger an update in the backward pass:

https://github.com/timmoon10/TransformerEngine/blob/f4e6af92e8956d948fe1fbaefbc1b2dd6f32b457/transformer_engine/pytorch/ops/fuser.py#L169-L171

if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
    FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

torch.Tensor:
    Output tensor

"""
Collaborator (reviewer): Add pass.

@timmoon10 (author): PyLint prefers just putting the docstring: 738df8a

Iterable of torch.Tensor:
    Loss gradients w.r.t. parameters

"""
Collaborator (reviewer): Add pass.

@timmoon10 (author): PyLint prefers just putting the docstring: 738df8a

    self.append(module)

def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
    self._module_groups = None
Collaborator (reviewer): self._module_groups is already set to None at the beginning of __init__. Why do we set it to None again?

@timmoon10 (author): If we add a module after calculating operation fusions, then we need to invalidate the operation fusions and recalculate.
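
A minimal sketch of the lazy-invalidation pattern being described (the _module_groups name comes from the snippet above; everything else is a simplified stand-in, not the real te.ops.Sequential):

from typing import Optional

import torch

class _SequentialSketch(torch.nn.Module):
    """Simplified container that caches a fusion plan and invalidates it on mutation"""

    def __init__(self) -> None:
        super().__init__()
        self._module_groups = None  # cached fusion plan

    def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
        # Adding a module after fusions were computed invalidates the cached plan
        self._module_groups = None
        super().add_module(name, module)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The plan is recomputed lazily on the next forward pass
        if self._module_groups is None:
            # Stand-in for the real fusion pass, which would group ops for fusing
            self._module_groups = list(self.children())
        for module in self._module_groups:
            x = module(x)
        return x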


def _get_keys_by_idx(self, idx: int | slice) -> list[str]:
    """Get module keys corresponding to indices"""
    if isinstance(idx, slice):
Collaborator (reviewer): Should there be a slice-indices check as well?

@timmoon10 (author): In principle, but it's simpler to rely on the bounds checking in list. This implementation is similar to torch.nn.Sequential:
https://github.com/pytorch/pytorch/blob/389492e2640730b0a199ffe506582ed4fd2c4afc/torch/nn/modules/container.py#L140

# Reshape FP8 tensor
# Note: Preserve cached transpose if possible
if is_float8_tensor(tensor):
    out = Float8Tensor.make_like(
Member (reviewer): How does this preserve the cache?

@timmoon10 (author): The transpose is part of Float8Tensor._fp8_attrs:

_transpose = property(**_make_fp8_attr_property_funcs("transpose"))

This function is not quite equivalent to the Float8Tensor's view or reshape functions since typically reshaping a tensor changes its transpose, while this function tries to preserve the 2D transpose.

def op_forward(
    self,
    ctx: OperationContext,
    input: torch.Tensor,  # pylint: disable=redefined-builtin
Member (reviewer): These are worth changing IMO: input → inp

@timmoon10 (author) on Jun 27, 2024: I'd agree for internal implementations, but input feels much better for a user-facing API:

op = te.ops.AllGather(...)
y = op(input=x)

I suppose BasicOperation.op_forward can be considered internal implementation, so I've changed the arg name to input_. I feel strongly about keeping the input arg in other functions like FusableOperation.forward.

Comment on lines +354 to +358
basic_op_ctxs[0],
input_,
basic_op_prev_ops[0],
basic_op_next_ops[0],
**basic_op_kwargs[0],
Member (reviewer): Could you explain why we index 0 here?

@timmoon10 (author): OperationFuser doesn't make any distinction between BasicOperation and FusedOperation, but interacts with them via the base class (e.g. FusableOperation.fuser_forward). A FusableOperation consists of one or more BasicOperations, so a BasicOperation will receive just one ctx from OperationFuser while a FusedOperation may receive multiple.

Fix spelling of "fusible". Avoid "input" name in internal APIs.
@timmoon10 (author): /te-ci pytorch

@timmoon10 (author): /te-ci pytorch

@timmoon10 (author): Merging with approval from @ksivaman, @sudhakarsingh27, @ptrendx. This feature is still experimental and incomplete.

Labels: enhancement (New feature or request)

3 participants