diff --git a/docs/source-fabric/api/fabric_args.rst b/docs/source-fabric/api/fabric_args.rst
index 4e140814a7657..b285fc2dd8f43 100644
--- a/docs/source-fabric/api/fabric_args.rst
+++ b/docs/source-fabric/api/fabric_args.rst
@@ -110,10 +110,12 @@ Learn more about :ref:`distributed multi-node training on clusters `_).
-Half precision, or mixed precision, combines 32 and 16-bit floating points to reduce the memory footprint during model training.
-Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while settings that only work in the specified precision have a ``"-true"`` suffix.
-This can result in improved performance, achieving significant speedups on modern GPUs.
+There are two different techniques for setting the precision: "true" precision and "mixed" precision.
+For an extensive guide to their differences, see :doc:`../fundamentals/precision`.
+
+Fabric supports floating-point operations in 64-bit precision ("double"), 32-bit precision ("full"), or 16-bit precision ("half"), with both regular float16 and `bfloat16 `_.
+The selected precision has a direct impact on performance and memory usage, depending on your hardware.
+Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while "true" precision settings have a ``"-true"`` suffix:
 
 .. code-block:: python
 
@@ -129,6 +131,9 @@ This can result in improved performance, achieving significant speedups on moder
     # 16-bit bfloat mixed precision (model weights remain in torch.float32)
     fabric = Fabric(precision="bf16-mixed", devices=1)
 
+    # 8-bit mixed precision via TransformerEngine (model weights remain in torch.float32)
+    fabric = Fabric(precision="transformer-engine", devices=1)
+
     # 16-bit precision (model weights get cast to torch.float16)
     fabric = Fabric(precision="16-true", devices=1)
 
@@ -138,8 +143,6 @@ This can result in improved performance, achieving significant speedups on moder
     # 64-bit (double) precision (model weights get cast to torch.float64)
     fabric = Fabric(precision="64-true", devices=1)
 
-See also: :doc:`../fundamentals/precision`
-
 plugins
 =======
 
diff --git a/docs/source-fabric/fundamentals/precision.rst b/docs/source-fabric/fundamentals/precision.rst
index 298c561d1df47..0ab7d46dcb718 100644
--- a/docs/source-fabric/fundamentals/precision.rst
+++ b/docs/source-fabric/fundamentals/precision.rst
@@ -14,8 +14,8 @@ Save memory with mixed precision
 What is Mixed Precision?
 ************************
 
-Like most deep learning frameworks, PyTorch trains on 32-bit floating-point (FP32) arithmetic by default.
-However, many deep learning models do not require this to reach complete accuracy.
+Like most deep learning frameworks, PyTorch runs on 32-bit floating-point (FP32) arithmetic by default.
+However, many deep learning models do not require full FP32 precision to reach complete accuracy during training.
 Mixed precision training delivers significant computational speedup by conducting operations in half-precision while keeping minimum information in single-precision to maintain as much information as possible in crucial areas of the network.
 Switching to mixed precision has resulted in considerable training speedups since the introduction of Tensor Cores in the Volta and Turing architectures.
 It combines FP32 and lower-bit floating points (such as FP16) to reduce memory footprint and increase performance during model training and evaluation.
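To make the memory claim above concrete, here is a minimal sketch (illustrative only, not part of this patch) comparing the storage of the same tensor in FP32 and FP16; the halving follows directly from the per-element sizes (4 bytes vs. 2 bytes).

.. code-block:: python

    import torch

    # float32 stores 4 bytes per element; float16 and bfloat16 store 2
    weights_fp32 = torch.randn(1024, 1024, dtype=torch.float32)
    weights_fp16 = weights_fp32.to(torch.float16)

    print(weights_fp32.nelement() * weights_fp32.element_size())  # 4194304 bytes
    print(weights_fp16.nelement() * weights_fp16.element_size())  # 2097152 bytes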
@@ -31,25 +31,34 @@ This is how you select the precision in Fabric:
     # This is the default
     fabric = Fabric(precision="32-true")
 
-    # Also FP32
+    # Also FP32 (legacy)
     fabric = Fabric(precision=32)
 
-    # FP32 as well
+    # FP32 as well (legacy)
     fabric = Fabric(precision="32")
 
-    # FP16 mixed precision
+    # Float16 mixed precision
     fabric = Fabric(precision="16-mixed")
 
-    # BFloat16 precision (Volta GPUs and later)
+    # Float16 true half precision
+    fabric = Fabric(precision="16-true")
+
+    # BFloat16 mixed precision (Volta GPUs and later)
     fabric = Fabric(precision="bf16-mixed")
 
+    # BFloat16 true half precision (Volta GPUs and later)
+    fabric = Fabric(precision="bf16-true")
+
+    # 8-bit mixed precision via TransformerEngine (Hopper GPUs and later)
+    fabric = Fabric(precision="transformer-engine")
+
     # Double precision
     fabric = Fabric(precision="64-true")
 
-    # Or
+    # Or (legacy)
     fabric = Fabric(precision="64")
 
-    # Or
+    # Or (legacy)
     fabric = Fabric(precision=64)
 
@@ -75,7 +84,7 @@ FP16 Mixed Precision
 
 In most cases, mixed precision uses FP16. Supported `PyTorch operations `_ automatically run in FP16, saving memory and improving throughput on the supported accelerators.
-Since computation happens in FP16, there is a chance of numerical instability during training.
+Since computation happens in FP16, which has a very limited "dynamic range", there is a chance of numerical instability during training.
 This is handled internally by a dynamic grad scaler which skips invalid steps and adjusts the scaler to ensure subsequent steps fall within a finite range.
 For more information `see the autocast docs `_.
 
@@ -114,7 +123,47 @@ It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDN
 .. note::
 
     BFloat16 may not provide significant speedups or memory improvements, offering better numerical stability.
-    For GPUs, the most significant benefits require `Ampere `_ based GPUs, such as A100s or 3090s.
+    For GPUs, the most significant benefits require `Ampere `_ based GPUs or newer, such as A100s or 3090s.
+
+
+----
+
+
+*****************************************************
+Float8 Mixed Precision via Nvidia's TransformerEngine
+*****************************************************
+
+`Transformer Engine `__ (TE) is a library for accelerating models on the latest NVIDIA GPUs. It uses 8-bit floating
+point (FP8) precision on Hopper GPUs to provide better performance and lower memory utilization in both training and
+inference, with no degradation in accuracy compared to half precision.
+
+Using TE requires replacing some of the layers in your model. Fabric automatically replaces the :class:`torch.nn.Linear`
+and :class:`torch.nn.LayerNorm` layers in your model with their TE alternatives; however, TE also offers
+`fused layers `__
+to squeeze out all the possible performance. If Fabric detects that any layer has already been replaced, automatic
+replacement is not done.
+
+This plugin is a mix of "mixed" and "true" precision. The computation is downcast to FP8 precision on the fly, but
+the model and inputs can be kept in true full or half precision.
+
+.. code-block:: python
+
+    # Select 8-bit mixed precision via TransformerEngine
+    fabric = Fabric(precision="transformer-engine")
+
+    # Customize the fp8 recipe or set a different base precision:
+    from lightning.fabric.plugins.precision import TransformerEnginePrecision
+
+    recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
+    precision = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
+    fabric = Fabric(plugins=precision)
+
+
+Under the hood, we use `transformer_engine.pytorch.fp8_autocast `__ with the default fp8 recipe.
+
+.. note::
+
+    This requires `Hopper `_ based GPUs or newer, such as the H100.
 
 
 ----
 
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 3acd73df4e1ee..442abcfeec020 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -55,6 +55,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for true half-precision as `L.Fabric(precision="16-true"|"bf16-true")` ([#17287](https://github.com/Lightning-AI/lightning/pull/17287))
 
+- Added support for mixed 8-bit precision as `L.Fabric(precision="transformer-engine")` using [Nvidia's Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine) ([#17597](https://github.com/Lightning-AI/lightning/pull/17597))
+
+
 - Added error messaging for missed `.launch()` when it is required ([#17570](https://github.com/Lightning-AI/lightning/pull/17570))
 
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index 255d439e6df92..0e57d5f2f1df8 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -50,6 +50,7 @@
     _PRECISION_INPUT_STR_ALIAS,
     _PRECISION_INPUT_STR_ALIAS_CONVERSION,
 )
+from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision
 from lightning.fabric.strategies import (
     DeepSpeedStrategy,
     ParallelStrategy,
@@ -449,6 +450,8 @@ def _check_and_init_precision(self) -> Precision:
             return Precision()
         if self._precision_input == "64-true":
             return DoublePrecision()
+        if self._precision_input == "transformer-engine":
+            return TransformerEnginePrecision()
 
         if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
             rank_zero_warn(
diff --git a/src/lightning/fabric/plugins/precision/__init__.py b/src/lightning/fabric/plugins/precision/__init__.py
index f5c5ac9817fb9..b5b1ca0ef0d27 100644
--- a/src/lightning/fabric/plugins/precision/__init__.py
+++ b/src/lightning/fabric/plugins/precision/__init__.py
@@ -17,6 +17,7 @@
 from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
 from lightning.fabric.plugins.precision.half import HalfPrecision
 from lightning.fabric.plugins.precision.precision import Precision
+from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision
 from lightning.fabric.plugins.precision.xla import XLAPrecision
 from lightning.fabric.plugins.precision.xlabf16 import XLABf16Precision
 
@@ -29,4 +30,5 @@
     "XLAPrecision",
     "XLABf16Precision",
     "FSDPPrecision",
+    "TransformerEnginePrecision",
 ]
diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py
index 147bd55800b78..1add95da884ad 100644
--- a/src/lightning/fabric/plugins/precision/precision.py
+++ b/src/lightning/fabric/plugins/precision/precision.py
@@ -23,7 +23,9 @@
 _PRECISION_INPUT_INT = Literal[64, 32, 16]
 _PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed",
"bf16": "bf16-mixed"} _PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"] -_PRECISION_INPUT_STR = Literal["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"] +_PRECISION_INPUT_STR = Literal[ + "transformer-engine", "16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true" +] _PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS] diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py new file mode 100644 index 0000000000000..3d8d0c4ccfdc9 --- /dev/null +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -0,0 +1,163 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from contextlib import contextmanager +from typing import Any, Generator, Literal, Mapping, Optional, TYPE_CHECKING, Union + +import torch +from lightning_utilities import apply_to_collection +from lightning_utilities.core.imports import RequirementCache +from torch import Tensor + +from lightning.fabric.plugins.precision.precision import Precision +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor +from lightning.fabric.utilities.rank_zero import rank_zero_warn + +_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0") + +if TYPE_CHECKING and _TRANSFORMER_ENGINE_AVAILABLE: + from transformer_engine.common.recipe import DelayedScaling + + +log = logging.getLogger(__name__) + + +class TransformerEnginePrecision(Precision): + """Plugin for training with fp8 precision via nvidia's `Transformer Engine + ` feature. + + Args: + dtype: The base dtype to use. + recipe: Recipe for the DelayedScaling + `configuration None: + if not _TRANSFORMER_ENGINE_AVAILABLE: + raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE)) + from transformer_engine.common.recipe import DelayedScaling + + if recipe is None: + recipe = DelayedScaling() + elif isinstance(recipe, Mapping): + recipe = dict(recipe) # copy + if "fp8_format" in recipe: + from transformer_engine.common.recipe import Format + + recipe["fp8_format"] = getattr(Format, recipe["fp8_format"]) + recipe = DelayedScaling(**recipe) + + if dtype is None: + dtype = torch.get_default_dtype() + self.dtype = dtype + self.recipe = recipe + self.replace_layers = replace_layers + + def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: + # avoid converting if any is found. 
assume the user took care of it + if self.replace_layers and not any("transformer_engine" in m.__module__ for m in module.modules()): + _convert_layers(module) + module = module.to(dtype=self.dtype) + return module + + @contextmanager + def init_context(self) -> Generator[None, None, None]: + import transformer_engine.pytorch as te + + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(self.dtype) + + replace_layers = self.replace_layers + if replace_layers: + original_linear = torch.nn.Linear + original_layer_norm = torch.nn.LayerNorm + torch.nn.Linear = te.Linear # type: ignore[misc] + torch.nn.LayerNorm = te.LayerNorm # type: ignore[misc] + + yield + + if replace_layers: + torch.nn.Linear = original_linear # type: ignore[misc] + torch.nn.LayerNorm = original_layer_norm # type: ignore[misc] + + torch.set_default_dtype(default_dtype) + + @contextmanager + def forward_context(self) -> Generator[None, None, None]: + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(self.dtype) + + import transformer_engine.pytorch as te + + with te.fp8_autocast(enabled=True, fp8_recipe=self.recipe): + yield + + torch.set_default_dtype(default_dtype) + + def convert_input(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype) + + def convert_output(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) + + +def _convert_layers(module: torch.nn.Module) -> None: + import transformer_engine.pytorch as te + + for name, child in module.named_children(): + if isinstance(child, torch.nn.Linear): + if child.in_features % 8 != 0 or child.out_features % 16 != 0: + # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting + rank_zero_warn( + "Support for FP8 in the linear layers with `precision='transformer-engine'` is currently limited to" + "tensors with shapes where the dimensions are divisible by 8 and 16 respectively." + f"The layer {name!r} does not fit this criteria. You might want to add padding to your inputs." + ) + continue + has_bias = child.bias is not None + replacement = te.Linear(child.in_features, child.out_features, bias=has_bias) + replacement.weight.data = child.weight.data.clone() + if has_bias: + replacement.bias.data = child.bias.data.clone() + log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent") + module.__setattr__(name, replacement) + elif isinstance(child, torch.nn.LayerNorm): + replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps) + replacement.weight.data = child.weight.data.clone() + replacement.bias.data = child.bias.data.clone() + log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent") + module.__setattr__(name, replacement) + else: + # there are other transformer engine layers that we could convert but require fusion. 
full list at: + # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html + _convert_layers(child) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index fcf35c8a68094..8fa157342e6bc 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -13,6 +13,7 @@ # limitations under the License import inspect import os +import sys from typing import Any, Dict from unittest import mock from unittest.mock import Mock @@ -48,6 +49,7 @@ XLAEnvironment, ) from lightning.fabric.plugins.io import TorchCheckpointIO +from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision from lightning.fabric.strategies import ( DataParallelStrategy, DDPStrategy, @@ -1037,3 +1039,88 @@ def _mock_interactive(): assert isinstance(connector.strategy.cluster_environment, XLAEnvironment) assert connector.strategy.launcher._start_method == "fork" assert connector.strategy.launcher.is_interactive_compatible + + +def test_connector_transformer_engine(monkeypatch): + monkeypatch.setattr( + lightning.fabric.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True + ) + transformer_engine_mock = Mock() + monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) + recipe_mock = Mock() + monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", recipe_mock) + + connector = _Connector(precision="transformer-engine") + assert isinstance(connector.precision, TransformerEnginePrecision) + + recipe_mock.reset_mock() + precision = TransformerEnginePrecision() + connector = _Connector(plugins=precision) + assert connector.precision is precision + assert precision.dtype == torch.float32 + recipe_mock.DelayedScaling.assert_called_once_with() + + recipe_mock.reset_mock() + recipe = {"foo": 0, "fp8_format": "HYBRID"} + precision = TransformerEnginePrecision(dtype=torch.float16, recipe=recipe) + connector = _Connector(plugins=precision) + assert connector.precision is precision + recipe_mock.DelayedScaling.assert_called_once_with(foo=0, fp8_format=recipe_mock.Format.HYBRID) + assert isinstance(recipe["fp8_format"], str) # not modified + + class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.l = torch.nn.Linear(1, 3) + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(16, 48) + self.l2 = torch.nn.LayerNorm(1) + self.l3 = SubModule() + + monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock()) + model = MyModule() + + precision.replace_layers = False + precision.convert_module(model) + assert isinstance(model.l1, torch.nn.Linear) + assert model.l1.weight.dtype == torch.float16 + assert isinstance(model.l3.l, torch.nn.Linear) + assert isinstance(model.l2, torch.nn.LayerNorm) + + precision.replace_layers = True + setattr_mock = Mock() + model.__setattr__ = setattr_mock + with pytest.warns(match="divisible by 8 and 16"): + precision.convert_module(model) + mock_calls = setattr_mock.mock_calls + assert len(mock_calls) == 2 + assert mock_calls[0][1][0] == "l1" + assert mock_calls[1][1][0] == "l2" + assert mock_calls[0][1][1]._extract_mock_name() == "mock.pytorch.Linear()" + assert mock_calls[1][1][1]._extract_mock_name() == "mock.pytorch.LayerNorm()" + + precision.replace_layers = False + with precision.init_context(): + model = MyModule() + assert isinstance(model.l1, torch.nn.Linear) + assert isinstance(model.l2, torch.nn.LayerNorm) + assert 
isinstance(model.l3.l, torch.nn.Linear) + + class TELinearMock(Mock): + ... + + class TELayerNormMock(Mock): + ... + + transformer_engine_mock.pytorch.Linear = TELinearMock + transformer_engine_mock.pytorch.LayerNorm = TELayerNormMock + precision.replace_layers = True + with precision.init_context(): + assert torch.get_default_dtype() == torch.float16 + model = MyModule() + assert isinstance(model.l1, TELinearMock) + assert isinstance(model.l2, TELayerNormMock) + assert isinstance(model.l3.l, TELinearMock)
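For reference, here is a minimal end-to-end sketch of how the plugin introduced in this diff would be used from Fabric. It is assembled from the documented examples above rather than taken from the patch itself, assumes a Hopper GPU with the `transformer_engine` package installed, and the layer sizes are chosen to satisfy the divisibility check in `_convert_layers`.

.. code-block:: python

    # Illustrative sketch; requires a Hopper GPU and transformer_engine installed.
    import torch
    import lightning as L
    from lightning.fabric.plugins.precision import TransformerEnginePrecision

    # Shorthand: FP8 compute, model weights kept in the default dtype
    fabric = L.Fabric(accelerator="cuda", devices=1, precision="transformer-engine")

    # Or configure the plugin explicitly: bfloat16 base dtype and a custom fp8 recipe
    recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
    plugin = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
    fabric = L.Fabric(accelerator="cuda", devices=1, plugins=plugin)
    fabric.launch()

    # in_features divisible by 8 and out_features divisible by 16, so TE can replace the Linear layer
    model = torch.nn.Sequential(torch.nn.LayerNorm(64), torch.nn.Linear(64, 64))
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)  # Linear/LayerNorm get swapped for TE equivalents here

    output = model(torch.randn(16, 64, device=fabric.device))  # forward pass runs under fp8_autocast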