Support NVIDIA's Transformer Engine #17597

Merged · 28 commits · Jul 19, 2023
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for true half-precision as `L.Fabric(precision="16-true"|"bf16-true")` ([#17287](https://github.com/Lightning-AI/lightning/pull/17287))


- Added support for mixed 8bit precision as `L.Fabric(precision="8-mixed"|"8-mixed-transformer-engine")` using [Nvidia's Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine) ([#17597](https://github.com/Lightning-AI/lightning/pull/17597))


- Added error messaging for missed `.launch()` when it is required ([#17570](https://github.com/Lightning-AI/lightning/pull/17570))


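For context only (not part of the change set): a minimal sketch of how the new flag is intended to be used, assuming a CUDA device with FP8 support and the transformer_engine package installed.

import lightning as L

# "8-mixed" and "8-mixed-transformer-engine" both resolve to the new Fp8TransformerEnginePrecision plugin
fabric = L.Fabric(accelerator="cuda", devices=1, precision="8-mixed")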
3 changes: 3 additions & 0 deletions src/lightning/fabric/connector.py
@@ -42,6 +42,7 @@
TorchElasticEnvironment,
)
from lightning.fabric.plugins.precision.double import DoublePrecision
from lightning.fabric.plugins.precision.fp8_transformer_engine import Fp8TransformerEnginePrecision
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.precision import (
_PRECISION_INPUT,
@@ -448,6 +449,8 @@ def _check_and_init_precision(self) -> Precision:
return Precision()
if self._precision_input == "64-true":
return DoublePrecision()
if self._precision_input in ("8-mixed", "8-mixed-transformer-engine"):
return Fp8TransformerEnginePrecision()

if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
65 changes: 65 additions & 0 deletions src/lightning/fabric/plugins/precision/fp8_transformer_engine.py
@@ -0,0 +1,65 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, Literal, Optional, TYPE_CHECKING, Union

from lightning_utilities.core.imports import RequirementCache

from lightning.fabric.plugins.precision.precision import Precision

_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine")

if TYPE_CHECKING and _TRANSFORMER_ENGINE_AVAILABLE:
from transformer_engine.common.recipe import DelayedScaling
else:
DelayedScaling = None


class Fp8TransformerEnginePrecision(Precision):
"""Plugin for training with fp8 precision via nvidia's `Transformer Engine
<https://docs.nvidia.com/deeplearning/transformer-engine`__.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

Args:
precision: The precision literal used to select this plugin.
recipe: Recipe for the DelayedScaling
`configuration <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling>`__.
Given either as a dict of keyword arguments or as the dataclass instance.
"""

precision: Literal["8-mixed-transformer-engine"] = "8-mixed-transformer-engine"

def __init__(self, recipe: Optional[Union[Dict[str, Any], "DelayedScaling"]] = None) -> None:
if not _TRANSFORMER_ENGINE_AVAILABLE:
raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
from transformer_engine.common.recipe import DelayedScaling

if recipe is None:
recipe = DelayedScaling()
elif isinstance(recipe, dict):
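# allow passing `fp8_format` as a string, e.g. "HYBRID", and resolve it to the corresponding `Format` enum member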
if "fp8_format" in recipe:
from transformer_engine.common.recipe import Format

recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
recipe = DelayedScaling(**recipe)

self.recipe = recipe

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
import transformer_engine.pytorch as te

with te.fp8_autocast(enabled=True, fp8_recipe=self.recipe):
yield
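For context only (not part of the change set): a sketch of constructing the plugin directly with a dict recipe and handing it to Fabric through `plugins=`. The `margin` and `fp8_format` keys are assumed to match fields of Transformer Engine's `DelayedScaling` dataclass at the time of this PR and may differ in other releases.

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision.fp8_transformer_engine import Fp8TransformerEnginePrecision

# the dict form is converted to a DelayedScaling instance by the plugin's __init__
precision = Fp8TransformerEnginePrecision(recipe={"fp8_format": "HYBRID", "margin": 0})
fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)

Passing a plugin instance this way bypasses the string lookup in `_check_and_init_precision`, which is what the new test added below exercises.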
4 changes: 3 additions & 1 deletion src/lightning/fabric/plugins/precision/precision.py
@@ -23,7 +23,9 @@
_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
_PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT_STR = Literal["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"]
_PRECISION_INPUT_STR = Literal[
"8-mixed", "8-mixed-transformer-engine", "16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"
]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS]


30 changes: 30 additions & 0 deletions tests/tests_fabric/test_connector.py
@@ -13,6 +13,7 @@
# limitations under the License
import inspect
import os
import sys
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -40,6 +41,7 @@
XLAEnvironment,
)
from lightning.fabric.plugins.io import TorchCheckpointIO
from lightning.fabric.plugins.precision.fp8_transformer_engine import Fp8TransformerEnginePrecision
from lightning.fabric.strategies import (
DataParallelStrategy,
DDPStrategy,
@@ -1014,3 +1016,31 @@ def _mock_interactive():
assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
assert connector.strategy.launcher._start_method == "fork"
assert connector.strategy.launcher.is_interactive_compatible


def test_connector_fp8_transformer_engine(monkeypatch):
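# make `transformer_engine` look installed so the plugin can be constructed without the real package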
monkeypatch.setattr(
lightning.fabric.plugins.precision.fp8_transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True
)
monkeypatch.setitem(sys.modules, "transformer_engine", Mock())
monkeypatch.setitem(sys.modules, "transformer_engine.common", Mock())
recipe_mock = Mock()
monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", recipe_mock)

connector = _Connector(precision="8-mixed")
assert isinstance(connector.precision, Fp8TransformerEnginePrecision)

connector = _Connector(precision="8-mixed-transformer-engine")
assert isinstance(connector.precision, Fp8TransformerEnginePrecision)

recipe_mock.reset_mock()
precision = Fp8TransformerEnginePrecision()
connector = _Connector(plugins=precision)
assert connector.precision is precision
recipe_mock.DelayedScaling.assert_called_once_with()

recipe_mock.reset_mock()
precision = Fp8TransformerEnginePrecision({"foo": 0, "fp8_format": "HYBRID"})
connector = _Connector(plugins=precision)
assert connector.precision is precision
recipe_mock.DelayedScaling.assert_called_once_with(foo=0, fp8_format=recipe_mock.Format.HYBRID)