Akoumparouli/nemo ux precision plugin refactor #10129

Merged: 16 commits, Aug 20, 2024
2 changes: 0 additions & 2 deletions nemo/lightning/fabric/plugins.py
@@ -124,6 +124,4 @@ def forward_context(self) -> Generator[None, None, None]:
def _convert_megatron_mixed_precision(plugin: MegatronMixedPrecision) -> FabricMegatronMixedPrecision:
return FabricMegatronMixedPrecision(
precision=plugin.precision,
device=plugin.device,
scaler=plugin.scaler,
)
159 changes: 107 additions & 52 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -13,15 +13,16 @@
# limitations under the License.

from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning.plugins.precision import MixedPrecision
from pytorch_lightning.plugins.precision import Precision
from torch.nn import Module
from torch.optim import Optimizer

from nemo.lightning._strategy_lib import GradScaler
from nemo.utils import logging

AnyT = TypeVar("AnyT")

@@ -33,18 +34,93 @@ def get_optim_config(optimizer: Optimizer):
raise ValueError("Failed to extract optimizer config from module.")


class MegatronMixedPrecision(MixedPrecision):
@dataclass
class DtypeConfig:
fp32: bool = False
fp16: bool = False
bf16: bool = False
params_dtype: torch.dtype = None
pipeline_dtype: torch.dtype = None
autocast_dtype: torch.dtype = None
autocast_enabled: bool = False
grad_reduce_in_fp32: bool = True
# fp8 related
fp8: str = None
fp8_margin: int = 0
fp8_interval: int = 1
fp8_amax_history_len: int = 1
fp8_amax_compute_algo: str = "most_recent"
fp8_wgrad: bool = True
fp8_dot_product_attention: bool = False
fp8_multi_head_attention: bool = False
# FP16 Loss scaling
loss_scale: float = None
initial_loss_scale: float = None
min_loss_scale: float = None
loss_scale_window: float = None
hysteresis: float = None


class MegatronMixedPrecision(Precision):
def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device="cuda",
precision: Literal["16-mixed", "bf16-mixed", "32"],
params_dtype: torch.dtype = None,
pipeline_dtype: torch.dtype = None,
autocast_dtype: torch.dtype = None,
autocast_enabled: bool = False,
grad_reduce_in_fp32: bool = True,
# fp8 related
fp8: str = None,
fp8_margin: int = 0,
fp8_interval: int = 1,
fp8_amax_history_len: int = 1,
fp8_amax_compute_algo: str = "most_recent",
fp8_wgrad: bool = True,
fp8_dot_product_attention: bool = False,
fp8_multi_head_attention: bool = False,
fp16_loss_scale: float = None,
fp16_initial_loss_scale: float = 4294967296,
fp16_min_loss_scale: float = 1.0,
fp16_loss_scale_window: int = 1000,
fp16_hysteresis: int = 2,
) -> None:
if precision == "bf16-mixed":
scaler = None
else:
scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2)

super().__init__(precision, device, scaler)
if isinstance(precision, int):
precision = str(precision)

dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
self.dtype_config = DtypeConfig(
fp32=precision in ['fp32', '32'],
fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
bf16=precision in ['bf16', 'bf16-mixed'],
params_dtype=params_dtype or torch.float32,
pipeline_dtype=pipeline_dtype or dtype,
autocast_dtype=autocast_dtype or dtype,
autocast_enabled=autocast_enabled,
grad_reduce_in_fp32=grad_reduce_in_fp32,
fp8=fp8,
fp8_margin=fp8_margin,
fp8_interval=fp8_interval,
fp8_amax_history_len=fp8_amax_history_len,
fp8_amax_compute_algo=fp8_amax_compute_algo,
fp8_wgrad=fp8_wgrad,
fp8_dot_product_attention=fp8_dot_product_attention,
fp8_multi_head_attention=fp8_multi_head_attention,
# fp16 loss scale
loss_scale=fp16_loss_scale,
initial_loss_scale=fp16_initial_loss_scale,
min_loss_scale=fp16_min_loss_scale,
loss_scale_window=fp16_loss_scale_window,
hysteresis=fp16_hysteresis,
)
super().__init__()
if self.dtype_config.fp16:
self.precision = "16-mixed"
elif self.dtype_config.bf16:
self.precision = "bf16-mixed"
else:
self.precision = "32-true"

def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
@@ -55,11 +131,11 @@ def convert_module(self, module: Module) -> Module:
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_model_config

if self.precision in ["16-mixed", "bf16-mixed"]:
if self.dtype_config.fp16 or self.dtype_config.bf16:
# Patch config options
config = get_model_config(module.module)
config.fp16 = self.precision == "16-mixed"
config.bf16 = self.precision == "bf16-mixed"
config.autocast = False
config.fp16 = self.dtype_config.fp16
config.bf16 = self.dtype_config.bf16
if hasattr(module, 'module'):
module.module = Float16Module(config, module.module)
else:
@@ -74,8 +150,8 @@ def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:

"""
optim_config = get_optim_config(optimizer)
assert optim_config.bf16 == (self.precision == "bf16-mixed"), "BF16 enabled on model but not on optimizer"
assert optim_config.fp16 == (self.precision == "fp16-mixed"), "BF16 enabled on model but not on optimizer"
assert optim_config.bf16 == self.dtype_config.bf16, "BF16 enabled on model but not on optimizer"
assert optim_config.fp16 == self.dtype_config.fp16, "FP16 enabled on model but not on optimizer"
return optimizer

def convert_input(self, data: AnyT) -> AnyT:
@@ -96,42 +172,6 @@ def convert_output(self, data: AnyT) -> AnyT:
"""
return data

def optimizer_step(
self,
optimizer: torch.optim.Optimizer,
model: Union["pl.LightningModule", torch.nn.Module],
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
from nemo.core.optim import MainParamsOptimizerWrapper

if not isinstance(optimizer, MainParamsOptimizerWrapper):
return super().optimizer_step(optimizer, model, closure, **kwargs)

if self.scaler is None:
assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation"
_ = closure()
self._after_closure(model, optimizer)
return optimizer.step(**kwargs)

assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation"
closure_result = closure()

# TODO: Add an option for merged all-reduce

# cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update
optimizer.copy_model_grads_to_main_grads()
# `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
# unscale main (fp32) gradients
self.scaler.unscale_(optimizer)
self._after_closure(model, optimizer)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
self.scaler.step(optimizer, **kwargs)
self.scaler.update()

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""No explicit precision casting. Inputs are supposed to be manually casted."""
@@ -141,4 +181,19 @@ def forward_context(self) -> Generator[None, None, None]:
pass


def update_config_with_dtype_overrides(dtype_config, config):
if hasattr(config, "__io__"):
config.__io__ = update_config_with_dtype_overrides(dtype_config, config.__io__)
for field in fields(dtype_config):
if not hasattr(config, field.name):
continue
# Warn whenever an existing config value gets overwritten.
old_val = getattr(config, field.name)
new_val = getattr(dtype_config, field.name)
if old_val != new_val:
setattr(config, field.name, new_val)
logging.warning(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}")
return config


__all__ = ["MegatronMixedPrecision"]
10 changes: 10 additions & 0 deletions nemo/lightning/pytorch/strategies.py
@@ -219,6 +219,12 @@ def connect(self, model: pl.LightningModule) -> None:
if _maybe_mcore_config:
self._mcore_config = _maybe_mcore_config

dtype_config = getattr(self._precision_plugin, 'dtype_config', None)
if dtype_config:
from nemo.lightning.pytorch.plugins.mixed_precision import update_config_with_dtype_overrides

model.config = update_config_with_dtype_overrides(dtype_config, model.config)

has_optim = getattr(model, "optim", None)
if has_optim:
opt_config = getattr(model.optim, "config", None)
@@ -228,6 +234,10 @@ def connect(self, model: pl.LightningModule) -> None:
raise ValueError("PyTorch DDP is not enabled for mcore optimizer")
ddp_config = cast(DistributedDataParallelConfig, self.ddp_config)

if dtype_config:
model.optim.config = update_config_with_dtype_overrides(dtype_config, model.optim.config)
self.ddp_config = update_config_with_dtype_overrides(dtype_config, self.ddp_config)

if mcore_opt_config.use_distributed_optimizer != ddp_config.use_distributed_optimizer:
from nemo.utils import logging

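To make the override mechanics in `connect()` concrete, here is a standalone sketch of `update_config_with_dtype_overrides` applied to a hypothetical stand-in config (`StubConfig` is invented for illustration; the real call sites above pass the Megatron model, optimizer, and DDP configs):

```python
# Illustrative sketch only; StubConfig is a hypothetical stand-in.
from dataclasses import dataclass

import torch

from nemo.lightning.pytorch.plugins.mixed_precision import (
    DtypeConfig,
    update_config_with_dtype_overrides,
)


@dataclass
class StubConfig:
    fp16: bool = True
    bf16: bool = False
    params_dtype: torch.dtype = torch.float32


dtype_config = DtypeConfig(bf16=True, fp16=False, params_dtype=torch.bfloat16)
cfg = update_config_with_dtype_overrides(dtype_config, StubConfig())

# Fields shared with DtypeConfig are overwritten, with a warning logged per change;
# fields the stub config does not define are skipped.
assert cfg.bf16 is True and cfg.fp16 is False
assert cfg.params_dtype is torch.bfloat16
```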
95 changes: 95 additions & 0 deletions tests/lightning/test_precision_plugin.py
@@ -0,0 +1,95 @@
import pytest
import pytorch_lightning as pl
import torch
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm


class DummyTokenizer:
def __init__(self):
self.vocab_size = 30000


class TestMegatronMixedPrecision:
"""Unit tests for the MegatronMixedPrecision class."""

@pytest.mark.run_only_on('GPU')
def test_precision_plugin_fp8_passed(self):
"""Test __init__ with default parameters."""

class TrainerHook(nl.Trainer):
def connect(self, model: pl.LightningModule) -> None:
assert model.config.bf16 == False
assert model.config.fp8 is None
super().connect(model)
assert model.config.fp8 == 'e4m3'
assert model.config.bf16 == True

trainer = TrainerHook(
devices=2,
accelerator="gpu",
max_steps=2,
strategy=nl.MegatronStrategy(
tensor_model_parallel_size=2,
sequence_parallel=True,
ckpt_include_optimizer=False,
),
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", fp8='e4m3'),
limit_val_batches=0.0,
num_sanity_val_steps=0,
)

optim = nl.MegatronOptimizerModule(
config=OptimizerConfig(
optimizer="adam",
lr=1e-5,
use_distributed_optimizer=False,
fp16=True,
params_dtype=torch.float32,
),
)
config = llm.Llama2Config7B()
config.num_layers = 2
model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim)
trainer.strategy.connect(model)

@pytest.mark.run_only_on('GPU')
def test_precision_plugin_precision_params_override(self):
"""Test __init__ with default parameters."""
trainer = nl.Trainer(
devices=2,
accelerator="gpu",
max_steps=2,
strategy=nl.MegatronStrategy(
tensor_model_parallel_size=2,
sequence_parallel=True,
ckpt_include_optimizer=False,
),
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
limit_val_batches=0.0,
num_sanity_val_steps=0,
)

optim = nl.MegatronOptimizerModule(
config=OptimizerConfig(
optimizer="adam",
lr=1e-5,
use_distributed_optimizer=False,
fp16=True,
params_dtype=torch.float32,
),
)
config = llm.Llama2Config7B()
config.num_layers = 2
config.fp16 = True
config.bf16 = False
model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim)
trainer.strategy.connect(model)
assert optim.config.bf16 is not None
assert optim.config.fp16 is not None
assert optim.config.bf16 == True
assert optim.config.fp16 == False
assert model.config.fp16 == False
assert model.config.bf16 == True