Support NVIDIA's Transformer Engine #17597

Merged: 28 commits, Jul 19, 2023
Changes from 5 commits
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for true half-precision as `L.Fabric(precision="16-true"|"bf16-true")` ([#17287](https://github.com/Lightning-AI/lightning/pull/17287))


- Added support for mixed 8-bit precision as `L.Fabric(precision="8-mixed"|"8-mixed-transformer-engine")` using [NVIDIA's Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine) ([#17597](https://github.com/Lightning-AI/lightning/pull/17597))


- Added error messaging for missed `.launch()` when it is required ([#17570](https://github.com/Lightning-AI/lightning/pull/17570))


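For orientation, here is a minimal usage sketch of the flag announced above. It is not taken from this PR's documentation; it assumes a CUDA device with FP8 support, an installed `transformer_engine` package, and that Fabric's `setup` routes modules through the plugin's `convert_module` and `forward_context`:

```python
# Hypothetical usage sketch (not part of this diff). Assumes an FP8-capable CUDA device
# and that `transformer_engine` is installed; it also assumes `fabric.setup` invokes the
# plugin's `convert_module` and that the wrapped forward runs inside `forward_context`.
import torch
import lightning as L

fabric = L.Fabric(accelerator="cuda", devices=1, precision="8-mixed")
fabric.launch()

# Both dimensions are divisible by 16, so the Linear layer is eligible for replacement.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.randn(16, 64, device=fabric.device)
loss = model(x).sum()  # forward is expected to run under te.fp8_autocast via the plugin
fabric.backward(loss)
optimizer.step()
```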
3 changes: 3 additions & 0 deletions src/lightning/fabric/connector.py
@@ -42,6 +42,7 @@
TorchElasticEnvironment,
)
from lightning.fabric.plugins.precision.double import DoublePrecision
from lightning.fabric.plugins.precision.fp8_transformer_engine import Fp8TransformerEnginePrecision
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.precision import (
_PRECISION_INPUT,
@@ -448,6 +449,8 @@ def _check_and_init_precision(self) -> Precision:
return Precision()
if self._precision_input == "64-true":
return DoublePrecision()
if self._precision_input in ("8-mixed", "8-mixed-transformer-engine"):
return Fp8TransformerEnginePrecision()

if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
116 changes: 116 additions & 0 deletions src/lightning/fabric/plugins/precision/fp8_transformer_engine.py
@@ -0,0 +1,116 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import contextmanager
from typing import Any, Dict, Generator, Literal, Optional, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.imports import RequirementCache

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.utilities.rank_zero import rank_zero_warn

_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine")

if TYPE_CHECKING and _TRANSFORMER_ENGINE_AVAILABLE:
from transformer_engine.common.recipe import DelayedScaling
else:
DelayedScaling = None


log = logging.getLogger(__name__)


class Fp8TransformerEnginePrecision(Precision):
"""Plugin for training with fp8 precision via nvidia's `Transformer Engine
<https://docs.nvidia.com/deeplearning/transformer-engine`__.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

Args:
recipe: Recipe for the DelayedScaling
`configuration <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling>`__.
Accepts a dict of keyword arguments or the dataclass itself.
replace_layers: Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their Transformer
Engine alternatives.
"""

precision: Literal["8-mixed-transformer-engine"] = "8-mixed-transformer-engine"

def __init__(
self, recipe: Optional[Union[Dict[str, Any], "DelayedScaling"]] = None, replace_layers: bool = True
) -> None:
if not torch.cuda.get_device_capability() >= (8, 1):
raise NotImplementedError("Your CUDA device does not support fp8 mixed precision.")
if not _TRANSFORMER_ENGINE_AVAILABLE:
raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
from transformer_engine.common.recipe import DelayedScaling

if recipe is None:
recipe = DelayedScaling()
elif isinstance(recipe, dict):
if "fp8_format" in recipe:
from transformer_engine.common.recipe import Format

recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
recipe = DelayedScaling(**recipe)

self.recipe = recipe
self.replace_layers = replace_layers

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
import transformer_engine.pytorch as te

with te.fp8_autocast(enabled=True, fp8_recipe=self.recipe):
yield

def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
# skip conversion if any Transformer Engine layer is already present; assume the user handled it
if self.replace_layers and not any("transformer_engine" in m.__module__ for m in module.modules()):
_convert_layers(module)
return module
# TODO: should we un-convert on teardown?


def _convert_layers(module: torch.nn.Module) -> None:
import transformer_engine.pytorch as te

for name, child in module.named_children():
if isinstance(child, torch.nn.Linear):
if child.in_features % 16 != 0 or child.out_features % 16 != 0:
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
rank_zero_warn(
"Support for FP8 in the linear layers with `precision='8-mixed'` is currently limited to tensors"
f" with shapes where both dimensions are divisible by 16. The layer {name!r} does not fit this"
" criteria. You might want to add padding to your inputs."
)
continue
has_bias = child.bias is not None
replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
replacement.weight.data = child.weight.data.clone()
if has_bias:
replacement.bias.data = child.bias.data.clone()
log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
module.__setattr__(name, replacement)
elif isinstance(child, torch.nn.LayerNorm):
replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
replacement.weight.data = child.weight.data.clone()
replacement.bias.data = child.bias.data.clone()
log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
module.__setattr__(name, replacement)
else:
# there are other transformer engine layers that we could convert but are more niche. full list at:
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
_convert_layers(child)
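The plugin can also be constructed explicitly and handed to Fabric via `plugins=`. A sketch under the same assumptions as above; the recipe keys (`fp8_format`, `margin`) are illustrative `DelayedScaling` arguments, and per the dict branch in `__init__` the `"fp8_format"` string is resolved to a `Format` member before the dataclass is built:

```python
# Sketch: constructing the plugin directly and passing it to Fabric via `plugins=`.
# The dict keys shown here are illustrative DelayedScaling arguments, not prescribed by this PR.
from lightning.fabric import Fabric
from lightning.fabric.plugins.precision.fp8_transformer_engine import Fp8TransformerEnginePrecision

precision = Fp8TransformerEnginePrecision(
    recipe={"fp8_format": "HYBRID", "margin": 0},  # "HYBRID" becomes transformer_engine's Format.HYBRID
    replace_layers=True,  # swap eligible Linear/LayerNorm layers in convert_module
)
fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)
```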
4 changes: 3 additions & 1 deletion src/lightning/fabric/plugins/precision/precision.py
@@ -23,7 +23,9 @@
_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
_PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT_STR = Literal["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"]
_PRECISION_INPUT_STR = Literal[
"8-mixed", "8-mixed-transformer-engine", "16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"
]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS]


58 changes: 58 additions & 0 deletions tests/tests_fabric/test_connector.py
@@ -13,6 +13,7 @@
# limitations under the License
import inspect
import os
import sys
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -40,6 +41,7 @@
XLAEnvironment,
)
from lightning.fabric.plugins.io import TorchCheckpointIO
from lightning.fabric.plugins.precision.fp8_transformer_engine import Fp8TransformerEnginePrecision
from lightning.fabric.strategies import (
DataParallelStrategy,
DDPStrategy,
@@ -1014,3 +1016,59 @@ def _mock_interactive():
assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
assert connector.strategy.launcher._start_method == "fork"
assert connector.strategy.launcher.is_interactive_compatible


@mock.patch("torch.cuda.get_device_capability", return_value=(9, 0))
def test_connector_fp8_transformer_engine(_, monkeypatch):
monkeypatch.setattr(
lightning.fabric.plugins.precision.fp8_transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True
)
monkeypatch.setitem(sys.modules, "transformer_engine", Mock())
monkeypatch.setitem(sys.modules, "transformer_engine.common", Mock())
recipe_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", recipe_mock)

connector = _Connector(precision="8-mixed")
assert isinstance(connector.precision, Fp8TransformerEnginePrecision)

connector = _Connector(precision="8-mixed-transformer-engine")
assert isinstance(connector.precision, Fp8TransformerEnginePrecision)

recipe_mock.reset_mock()
precision = Fp8TransformerEnginePrecision()
connector = _Connector(plugins=precision)
assert connector.precision is precision
recipe_mock.DelayedScaling.assert_called_once_with()

recipe_mock.reset_mock()
precision = Fp8TransformerEnginePrecision({"foo": 0, "fp8_format": "HYBRID"})
connector = _Connector(plugins=precision)
assert connector.precision is precision
recipe_mock.DelayedScaling.assert_called_once_with(foo=0, fp8_format=recipe_mock.Format.HYBRID)

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(16, 48)
self.l2 = torch.nn.Linear(1, 3)
self.l3 = torch.nn.LayerNorm(1)

monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock())
model = MyModule()

precision.replace_layers = False
precision.convert_module(model)
assert isinstance(model.l1, torch.nn.Linear)
assert isinstance(model.l3, torch.nn.LayerNorm)

precision.replace_layers = True
setattr_mock = Mock()
model.__setattr__ = setattr_mock
with pytest.warns(match="divisible by 16"):
precision.convert_module(model)
mock_calls = setattr_mock.mock_calls
assert len(mock_calls) == 2
assert mock_calls[0][1][0] == "l1"
assert mock_calls[1][1][0] == "l3"
assert mock_calls[0][1][1]._extract_mock_name() == "mock.pytorch.Linear()"
assert mock_calls[1][1][1]._extract_mock_name() == "mock.pytorch.LayerNorm()"