Akoumparouli/nemo ux precision plugin refactor (NVIDIA#10129)
* rename mixed_precision.py to precision.py

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* replace print with logging.warning

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* also patch ddp_config

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Rename patch_dtype_config to update_config_with_dtype_overrides

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add GradScaler's args to constructor's arg list

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* fix import

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Leverage mcore's fp16 grad scaler

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* remove unused param

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add precision plugin test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* Also update __io__ configs

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* remove unused imports

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix fabric to ptl converter mcore precision plugin

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa and akoumpa authored Aug 20, 2024
1 parent d4f02b5 commit 60442c2
Showing 4 changed files with 212 additions and 54 deletions.
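
The commit message above summarizes the API change; here is a minimal usage sketch of the refactored plugin, assuming the constructor introduced in this commit (the loss-scaling values are illustrative, not recommendations):

# Hedged sketch: FP16 training with explicit loss-scaling arguments, which this
# commit forwards to Megatron-Core's grad scaler rather than a locally defined GradScaler.
from nemo import lightning as nl

fp16_plugin = nl.MegatronMixedPrecision(
    precision="16-mixed",
    fp16_initial_loss_scale=2**32,  # illustrative value (matches the constructor default)
    fp16_loss_scale_window=1000,
    fp16_hysteresis=2,
)

# BF16 needs no loss scaling; gradient reduction can still be kept in FP32.
bf16_plugin = nl.MegatronMixedPrecision(precision="bf16-mixed", grad_reduce_in_fp32=True)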
2 changes: 0 additions & 2 deletions nemo/lightning/fabric/plugins.py
@@ -124,6 +124,4 @@ def forward_context(self) -> Generator[None, None, None]:
def _convert_megatron_mixed_precision(plugin: MegatronMixedPrecision) -> FabricMegatronMixedPrecision:
return FabricMegatronMixedPrecision(
precision=plugin.precision,
device=plugin.device,
scaler=plugin.scaler,
)
159 changes: 107 additions & 52 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -13,15 +13,16 @@
# limitations under the License.

from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning.plugins.precision import MixedPrecision
from pytorch_lightning.plugins.precision import Precision
from torch.nn import Module
from torch.optim import Optimizer

from nemo.lightning._strategy_lib import GradScaler
from nemo.utils import logging

AnyT = TypeVar("AnyT")

@@ -33,18 +34,93 @@ def get_optim_config(optimizer: Optimizer):
raise ValueError("Failed to extract optimizer config from module.")


class MegatronMixedPrecision(MixedPrecision):
@dataclass
class DtypeConfig:
fp32: bool = False
fp16: bool = False
bf16: bool = False
params_dtype: torch.dtype = None
pipeline_dtype: torch.dtype = None
autocast_dtype: torch.dtype = None
autocast_enabled: bool = False
grad_reduce_in_fp32: bool = True
# fp8 related
fp8: str = None
fp8_margin: int = 0
fp8_interval: int = 1
fp8_amax_history_len: int = 1
fp8_amax_compute_algo: str = "most_recent"
fp8_wgrad: bool = True
fp8_dot_product_attention: bool = False
fp8_multi_head_attention: bool = False
# FP16 Loss scaling
loss_scale: float = (None,)
initial_loss_scale: float = (None,)
min_loss_scale: float = (None,)
loss_scale_window: float = (None,)
hysteresis: float = (None,)


class MegatronMixedPrecision(Precision):
def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device="cuda",
precision: Literal["16-mixed", "bf16-mixed", "32"],
params_dtype: torch.dtype = None,
pipeline_dtype: torch.dtype = None,
autocast_dtype: torch.dtype = None,
autocast_enabled: bool = False,
grad_reduce_in_fp32: bool = True,
# fp8 related,
fp8: str = None,
fp8_margin: int = 0,
fp8_interval: int = 1,
fp8_amax_history_len: int = 1,
fp8_amax_compute_algo: str = "most_recent",
fp8_wgrad: bool = True,
fp8_dot_product_attention: bool = False,
fp8_multi_head_attention: bool = False,
fp16_loss_scale: float = None,
fp16_initial_loss_scale: float = 4294967296,
fp16_min_loss_scale: float = 1.0,
fp16_loss_scale_window: int = 1000,
fp16_hysteresis: int = 2,
) -> None:
if precision == "bf16-mixed":
scaler = None
else:
scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2)

super().__init__(precision, device, scaler)
if isinstance(precision, int):
precision = str(precision)

dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
self.dtype_config = DtypeConfig(
fp32=precision in ['fp32', '32'],
fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
bf16=precision in ['bf16', 'bf16-mixed'],
params_dtype=params_dtype or torch.float32,
pipeline_dtype=pipeline_dtype or dtype,
autocast_dtype=autocast_dtype or dtype,
autocast_enabled=autocast_enabled,
grad_reduce_in_fp32=grad_reduce_in_fp32,
fp8=fp8,
fp8_margin=fp8_margin,
fp8_interval=fp8_interval,
fp8_amax_history_len=fp8_amax_history_len,
fp8_amax_compute_algo=fp8_amax_compute_algo,
fp8_wgrad=fp8_wgrad,
fp8_dot_product_attention=fp8_dot_product_attention,
fp8_multi_head_attention=fp8_multi_head_attention,
# fp16 loss scale
loss_scale=fp16_loss_scale,
initial_loss_scale=fp16_initial_loss_scale,
min_loss_scale=fp16_min_loss_scale,
loss_scale_window=fp16_loss_scale_window,
hysteresis=fp16_hysteresis,
)
super().__init__()
if self.dtype_config.fp16:
self.precision = "16-mixed"
elif self.dtype_config.bf16:
self.precision = "bf16-mixed"
else:
self.precision = "32-true"

def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
@@ -55,11 +131,11 @@ def convert_module(self, module: Module) -> Module:
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_model_config

if self.precision in ["16-mixed", "bf16-mixed"]:
if self.dtype_config.fp16 or self.dtype_config.bf16:
# Patch config options
config = get_model_config(module.module)
config.fp16 = self.precision == "16-mixed"
config.bf16 = self.precision == "bf16-mixed"
config.autocast = False
config.fp16 = self.dtype_config.fp16
config.bf16 = self.dtype_config.bf16
if hasattr(module, 'module'):
module.module = Float16Module(config, module.module)
else:
@@ -74,8 +150,8 @@ def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""
optim_config = get_optim_config(optimizer)
assert optim_config.bf16 == (self.precision == "bf16-mixed"), "BF16 enabled on model but not on optimizer"
assert optim_config.fp16 == (self.precision == "fp16-mixed"), "BF16 enabled on model but not on optimizer"
assert optim_config.bf16 == self.dtype_config.bf16, "BF16 enabled on model but not on optimizer"
assert optim_config.fp16 == self.dtype_config.fp16, "BF16 enabled on model but not on optimizer"
return optimizer

def convert_input(self, data: AnyT) -> AnyT:
@@ -96,42 +172,6 @@ def convert_output(self, data: AnyT) -> AnyT:
"""
return data

def optimizer_step(
self,
optimizer: torch.optim.Optimizer,
model: Union["pl.LightningModule", torch.nn.Module],
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
from nemo.core.optim import MainParamsOptimizerWrapper

if not isinstance(optimizer, MainParamsOptimizerWrapper):
return super().optimizer_step(optimizer, model, closure, **kwargs)

if self.scaler is None:
assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation"
_ = closure()
self._after_closure(model, optimizer)
return optimizer.step(**kwargs)

assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation"
closure_result = closure()

# TODO: Add an option for merged all-reduce

# cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update
optimizer.copy_model_grads_to_main_grads()
# `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
# unscale main (fp32) gradients
self.scaler.unscale_(optimizer)
self._after_closure(model, optimizer)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
self.scaler.step(optimizer, **kwargs)
self.scaler.update()

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""No explicit precision casting. Inputs are supposed to be manually casted."""
@@ -141,4 +181,19 @@ def forward_context(self) -> Generator[None, None, None]:
pass


def update_config_with_dtype_overrides(dtype_config, config):
if hasattr(config, "__io__"):
config.__io__ = update_config_with_dtype_overrides(dtype_config, config.__io__)
for field in fields(dtype_config):
if not hasattr(config, field.name):
continue
# If we overwrote a value, throw a warning.
old_val = getattr(config, field.name)
new_val = getattr(dtype_config, field.name)
if old_val != new_val:
setattr(config, field.name, new_val)
logging.warning(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}")
return config


__all__ = ["MegatronMixedPrecision"]
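
A small sketch of how update_config_with_dtype_overrides behaves; the SimpleConfig stand-in is hypothetical, whereas real callers pass Megatron model, optimizer, or DDP config objects:

# Hedged sketch: fields shared between the plugin's DtypeConfig and a target config
# are overwritten (logging a warning for each change); attributes the target config
# does not define are skipped, and any nested __io__ config is patched recursively.
import torch
from nemo.lightning.pytorch.plugins.mixed_precision import (
    DtypeConfig,
    update_config_with_dtype_overrides,
)

class SimpleConfig:  # hypothetical stand-in for a Megatron config object
    bf16 = False
    fp8 = None
    params_dtype = torch.float32

dtype_config = DtypeConfig(bf16=True, fp8="e4m3", params_dtype=torch.bfloat16)
cfg = update_config_with_dtype_overrides(dtype_config, SimpleConfig())
assert cfg.bf16 is True and cfg.fp8 == "e4m3" and cfg.params_dtype is torch.bfloat16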
10 changes: 10 additions & 0 deletions nemo/lightning/pytorch/strategies.py
@@ -219,6 +219,12 @@ def connect(self, model: pl.LightningModule) -> None:
if _maybe_mcore_config:
self._mcore_config = _maybe_mcore_config

dtype_config = getattr(self._precision_plugin, 'dtype_config', None)
if dtype_config:
from nemo.lightning.pytorch.plugins.mixed_precision import update_config_with_dtype_overrides

model.config = update_config_with_dtype_overrides(dtype_config, model.config)

has_optim = getattr(model, "optim", None)
if has_optim:
opt_config = getattr(model.optim, "config", None)
@@ -228,6 +234,10 @@ def connect(self, model: pl.LightningModule) -> None:
raise ValueError("PyTorch DDP is not enabled for mcore optimizer")
ddp_config = cast(DistributedDataParallelConfig, self.ddp_config)

if dtype_config:
model.optim.config = update_config_with_dtype_overrides(dtype_config, model.optim.config)
self.ddp_config = update_config_with_dtype_overrides(dtype_config, self.ddp_config)

if mcore_opt_config.use_distributed_optimizer != ddp_config.use_distributed_optimizer:
from nemo.utils import logging

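
The connect() changes above mean a single plugin-level DtypeConfig drives the model, optimizer, and DDP configs. A hedged sketch of the flag mapping, exercised more fully by the new test below:

# Hedged sketch: the precision string is normalized once into DtypeConfig flags,
# which connect() then copies onto model.config, model.optim.config and ddp_config.
from nemo import lightning as nl

plugin = nl.MegatronMixedPrecision(precision="bf16-mixed", fp8="e4m3")
assert plugin.dtype_config.bf16 and not plugin.dtype_config.fp16
assert plugin.dtype_config.fp8 == "e4m3"
assert plugin.precision == "bf16-mixed"  # Lightning-facing precision string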
95 changes: 95 additions & 0 deletions tests/lightning/test_precision_plugin.py
@@ -0,0 +1,95 @@
import pytest
import pytorch_lightning as pl
import torch
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm


class DummyTokenizer:
def __init__(self):
self.vocab_size = 30000


class TestMegatronMixedPrecision:
"""Unit tests for the MegatronMixedPrecision class."""

@pytest.mark.run_only_on('GPU')
def test_precision_plugin_fp8_passed(self):
"""Test __init__ with default parameters."""

class TrainerHook(nl.Trainer):
def connect(self, model: pl.LightningModule) -> None:
assert model.config.bf16 == False
assert model.config.fp8 is None
super().connect(model)
assert model.config.fp8 == 'e4m3'
assert model.config.bf16 == True

trainer = TrainerHook(
devices=2,
accelerator="gpu",
max_steps=2,
strategy=nl.MegatronStrategy(
tensor_model_parallel_size=2,
sequence_parallel=True,
ckpt_include_optimizer=False,
),
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", fp8='e4m3'),
limit_val_batches=0.0,
num_sanity_val_steps=0,
)

optim = nl.MegatronOptimizerModule(
config=OptimizerConfig(
optimizer="adam",
lr=1e-5,
use_distributed_optimizer=False,
fp16=True,
params_dtype=torch.float32,
),
)
config = llm.Llama2Config7B()
config.num_layers = 2
model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim)
trainer.strategy.connect(model)

@pytest.mark.run_only_on('GPU')
def test_precision_plugin_precision_params_override(self):
"""Test __init__ with default parameters."""
trainer = nl.Trainer(
devices=2,
accelerator="gpu",
max_steps=2,
strategy=nl.MegatronStrategy(
tensor_model_parallel_size=2,
sequence_parallel=True,
ckpt_include_optimizer=False,
),
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
limit_val_batches=0.0,
num_sanity_val_steps=0,
)

optim = nl.MegatronOptimizerModule(
config=OptimizerConfig(
optimizer="adam",
lr=1e-5,
use_distributed_optimizer=False,
fp16=True,
params_dtype=torch.float32,
),
)
config = llm.Llama2Config7B()
config.num_layers = 2
config.fp16 = True
config.bf16 = False
model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim)
trainer.strategy.connect(model)
assert optim.config.bf16 is not None
assert optim.config.fp16 is not None
assert optim.config.bf16 == True
assert optim.config.fp16 == False
assert model.config.fp16 == False
assert model.config.bf16 == True
