From 0a0477e345378c59c58cbafb518121a452e3ec32 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 13 Aug 2024 13:43:18 -0700 Subject: [PATCH 01/16] rename mixed_precision.py to precision.py Signed-off-by: Alexandros Koumparoulis --- .../pytorch/plugins/mixed_precision.py | 93 ++++++++++++++++--- nemo/lightning/pytorch/strategies.py | 6 ++ 2 files changed, 87 insertions(+), 12 deletions(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 65b7c6292249..62035f8a4d25 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +w# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. from contextlib import contextmanager +from dataclasses import dataclass, fields from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union import pytorch_lightning as pl @@ -32,19 +33,74 @@ def get_optim_config(optimizer: Optimizer): except: raise ValueError("Failed to extract optimizer config from module.") +@dataclass +class DtypeConfig: + fp32: bool = False + fp16: bool = False + bf16: bool = False + params_dtype: torch.dtype = None + pipeline_dtype: torch.dtype = None + autocast_dtype: torch.dtype = None + autocast_enabled: bool = False + grad_reduce_in_fp32: bool = True + # fp8 related + fp8: str = None + fp8_margin: int = 0 + fp8_interval: int = 1 + fp8_amax_history_len: int = 1 + fp8_amax_compute_algo: str = "most_recent" + fp8_wgrad: bool = True + fp8_dot_product_attention: bool = False + fp8_multi_head_attention: bool = False + class MegatronMixedPrecision(MixedPrecision): def __init__( self, - precision: Literal["16-mixed", "bf16-mixed"], - device="cuda", + precision: Literal["16-mixed", "bf16-mixed", "32"], + params_dtype: torch.dtype = None, + pipeline_dtype: torch.dtype = None, + autocast_dtype: torch.dtype = None, + autocast_enabled: bool = False, + grad_reduce_in_fp32: bool = True, + # fp8 related, + fp8: str = None, + fp8_margin: int = 0, + fp8_interval: int = 1, + fp8_amax_history_len: int = 1, + fp8_amax_compute_algo: str = "most_recent", + fp8_wgrad: bool = True, + fp8_dot_product_attention: bool = False, + fp8_multi_head_attention: bool = False, + device: str = "cuda", ) -> None: - if precision == "bf16-mixed": - scaler = None - else: - scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2) - super().__init__(precision, device, scaler) + if isinstance(precision, int): + precision = str(precision) + + dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 + self.dtype_config = DtypeConfig( + fp32=precision in ['fp32', '32'], + fp16=precision in ['fp16', 'fp16-mixed', '16'], + bf16=precision in ['bf16', 'bf16-mixed'], + params_dtype=params_dtype or torch.float32, + pipeline_dtype=pipeline_dtype or dtype, + autocast_dtype=autocast_dtype or dtype, + autocast_enabled=autocast_enabled, + grad_reduce_in_fp32=grad_reduce_in_fp32, + fp8=fp8, + fp8_margin=fp8_margin, + fp8_interval=fp8_interval, + fp8_amax_history_len=fp8_amax_history_len, + fp8_amax_compute_algo=fp8_amax_compute_algo, + fp8_wgrad=fp8_wgrad, + fp8_dot_product_attention=fp8_dot_product_attention, + fp8_multi_head_attention=fp8_multi_head_attention, + ) + scaler 
= None + if self.dtype_config.fp16: + scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2) + super().__init__(self.dtype_config, device, scaler) def convert_module(self, module: Module) -> Module: """Convert the module parameters to the precision type this plugin handles. @@ -55,11 +111,11 @@ def convert_module(self, module: Module) -> Module: from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_model_config - if self.precision in ["16-mixed", "bf16-mixed"]: + if self.dtype_config.fp16 or self.dtype_config.bf16: + # Patch config options config = get_model_config(module.module) - config.fp16 = self.precision == "16-mixed" - config.bf16 = self.precision == "bf16-mixed" - config.autocast = False + config.fp16 = self.dtype_config.fp16 + config.bf16 = self.dtype_config.bf16 if hasattr(module, 'module'): module.module = Float16Module(config, module.module) else: @@ -141,4 +197,17 @@ def forward_context(self) -> Generator[None, None, None]: pass +def patch_dtype_config(dtype_config, config): + for field in fields(dtype_config): + if not hasattr(config, field.name): + continue + # If we overwrote a value, throw a warning. + old_val = getattr(config, field.name) + new_val = getattr(dtype_config, field.name) + if old_val != new_val: + setattr(config, field.name, new_val) + print(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}") + return config + + __all__ = ["MegatronMixedPrecision"] diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 0250709a4e03..47e4b54b2c8e 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -219,6 +219,12 @@ def connect(self, model: pl.LightningModule) -> None: if _maybe_mcore_config: self._mcore_config = _maybe_mcore_config + if hasattr(self._precision_plugin, 'dtype_config'): + from nemo.lightning.pytorch.plugins.precision import patch_dtype_config + + model.config = patch_dtype_config(self._precision_plugin.dtype_config, model.config) + model.optim.config = patch_dtype_config(self._precision_plugin.dtype_config, model.optim.config) + has_optim = getattr(model, "optim", None) if has_optim: opt_config = getattr(model.optim, "config", None) From ffeeb12b2021180e30e1eb51519d51533b979ae0 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 15 Aug 2024 12:09:51 -0700 Subject: [PATCH 02/16] replace print with logging.warning Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/plugins/mixed_precision.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 62035f8a4d25..a30ba7d8ff40 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -1,4 +1,4 @@ -w# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -23,6 +23,7 @@ from torch.optim import Optimizer from nemo.lightning._strategy_lib import GradScaler +from nemo.utils import logging AnyT = TypeVar("AnyT") @@ -206,7 +207,7 @@ def patch_dtype_config(dtype_config, config): new_val = getattr(dtype_config, field.name) if old_val != new_val: setattr(config, field.name, new_val) - print(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}") + logging.warning(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}") return config From dcc77b26cac5ad69ca878f7a2254d49e66a819b8 Mon Sep 17 00:00:00 2001 From: akoumpa Date: Fri, 16 Aug 2024 00:47:35 +0000 Subject: [PATCH 03/16] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/lightning/pytorch/plugins/mixed_precision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index a30ba7d8ff40..cad07baff5d9 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -34,6 +34,7 @@ def get_optim_config(optimizer: Optimizer): except: raise ValueError("Failed to extract optimizer config from module.") + @dataclass class DtypeConfig: fp32: bool = False From 6dd5bae312e6d010c4bad63a20dd1a41a684a81f Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 15 Aug 2024 21:53:24 -0700 Subject: [PATCH 04/16] also patch ddp_config Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/strategies.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 47e4b54b2c8e..f0bd64f983b9 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -223,7 +223,6 @@ def connect(self, model: pl.LightningModule) -> None: from nemo.lightning.pytorch.plugins.precision import patch_dtype_config model.config = patch_dtype_config(self._precision_plugin.dtype_config, model.config) - model.optim.config = patch_dtype_config(self._precision_plugin.dtype_config, model.optim.config) has_optim = getattr(model, "optim", None) if has_optim: @@ -234,6 +233,10 @@ def connect(self, model: pl.LightningModule) -> None: raise ValueError("PyTorch DDP is not enabled for mcore optimizer") ddp_config = cast(DistributedDataParallelConfig, self.ddp_config) + if hasattr(self._precision_plugin, 'dtype_config'): + model.optim.config = patch_dtype_config(self._precision_plugin.dtype_config, model.optim.config) + self.ddp_config = patch_dtype_config(self._precision_plugin.dtype_config, self.ddp_config) + if mcore_opt_config.use_distributed_optimizer != ddp_config.use_distributed_optimizer: from nemo.utils import logging From 8e6e85025d1d9a12ab7fad988441d3af211c7cca Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 15 Aug 2024 21:58:39 -0700 Subject: [PATCH 05/16] Rename patch_dtype_config to update_config_with_dtype_overrides Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/plugins/mixed_precision.py | 2 +- nemo/lightning/pytorch/strategies.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index cad07baff5d9..daa9801f028e 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -199,7 +199,7 @@ def forward_context(self) -> Generator[None, None, None]: pass -def 
patch_dtype_config(dtype_config, config): +def update_config_with_dtype_overrides(dtype_config, config): for field in fields(dtype_config): if not hasattr(config, field.name): continue diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index f0bd64f983b9..3ed15d8dc162 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -219,10 +219,11 @@ def connect(self, model: pl.LightningModule) -> None: if _maybe_mcore_config: self._mcore_config = _maybe_mcore_config - if hasattr(self._precision_plugin, 'dtype_config'): - from nemo.lightning.pytorch.plugins.precision import patch_dtype_config + dtype_config = getattr(self._precision_plugin, 'dtype_config', None) + if dtype_config: + from nemo.lightning.pytorch.plugins.precision import update_config_with_dtype_overrides - model.config = patch_dtype_config(self._precision_plugin.dtype_config, model.config) + model.config = update_config_with_dtype_overrides(dtype_config, model.config) has_optim = getattr(model, "optim", None) if has_optim: @@ -233,9 +234,9 @@ def connect(self, model: pl.LightningModule) -> None: raise ValueError("PyTorch DDP is not enabled for mcore optimizer") ddp_config = cast(DistributedDataParallelConfig, self.ddp_config) - if hasattr(self._precision_plugin, 'dtype_config'): - model.optim.config = patch_dtype_config(self._precision_plugin.dtype_config, model.optim.config) - self.ddp_config = patch_dtype_config(self._precision_plugin.dtype_config, self.ddp_config) + if dtype_config: + model.optim.config = update_config_with_dtype_overrides(dtype_config, model.optim.config) + self.ddp_config = update_config_with_dtype_overrides(dtype_config, self.ddp_config) if mcore_opt_config.use_distributed_optimizer != ddp_config.use_distributed_optimizer: from nemo.utils import logging From a28b998f303a7ce23be2b4a05b79b9c2f371ee97 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 15 Aug 2024 22:01:46 -0700 Subject: [PATCH 06/16] Add GradScaler's args to constructor's arg list Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/plugins/mixed_precision.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index daa9801f028e..ac16d56cf1f6 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -74,6 +74,9 @@ def __init__( fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, + native_amp_init_scale: int = 2**32, + native_amp_growth_interval: int = 1000, + native_amp_hysteresis: int = 2, device: str = "cuda", ) -> None: @@ -101,7 +104,7 @@ def __init__( ) scaler = None if self.dtype_config.fp16: - scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2) + scaler = GradScaler(init_scale=native_amp_init_scale, growth_interval=native_amp_growth_interval, hysteresis=native_amp_hysteresis) super().__init__(self.dtype_config, device, scaler) def convert_module(self, module: Module) -> Module: From aed66583cc6365aba890a89c732d74ce27437a1e Mon Sep 17 00:00:00 2001 From: akoumpa Date: Fri, 16 Aug 2024 05:02:34 +0000 Subject: [PATCH 07/16] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/lightning/pytorch/plugins/mixed_precision.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py 
b/nemo/lightning/pytorch/plugins/mixed_precision.py index ac16d56cf1f6..69e2c5119dd6 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -104,7 +104,11 @@ def __init__( ) scaler = None if self.dtype_config.fp16: - scaler = GradScaler(init_scale=native_amp_init_scale, growth_interval=native_amp_growth_interval, hysteresis=native_amp_hysteresis) + scaler = GradScaler( + init_scale=native_amp_init_scale, + growth_interval=native_amp_growth_interval, + hysteresis=native_amp_hysteresis, + ) super().__init__(self.dtype_config, device, scaler) def convert_module(self, module: Module) -> Module: From 4361c75e73a98579576e4dc9e69d707ae8bb1ef6 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Aug 2024 13:24:47 -0700 Subject: [PATCH 08/16] fix import Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 3ed15d8dc162..668b088a4864 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -221,7 +221,7 @@ def connect(self, model: pl.LightningModule) -> None: dtype_config = getattr(self._precision_plugin, 'dtype_config', None) if dtype_config: - from nemo.lightning.pytorch.plugins.precision import update_config_with_dtype_overrides + from nemo.lightning.pytorch.plugins.mixed_precision import update_config_with_dtype_overrides model.config = update_config_with_dtype_overrides(dtype_config, model.config) From 9dd24c4d0a246e40961651c0f4f45892238e9e80 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Aug 2024 13:42:59 -0700 Subject: [PATCH 09/16] Leverage mcore's fp16 grad scaler Signed-off-by: Alexandros Koumparoulis --- .../pytorch/plugins/mixed_precision.py | 74 ++++++------------- 1 file changed, 22 insertions(+), 52 deletions(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 69e2c5119dd6..bd70501c8e1a 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -18,11 +18,10 @@ import pytorch_lightning as pl import torch -from pytorch_lightning.plugins.precision import MixedPrecision +from pytorch_lightning.plugins.precision import Precision from torch.nn import Module from torch.optim import Optimizer -from nemo.lightning._strategy_lib import GradScaler from nemo.utils import logging AnyT = TypeVar("AnyT") @@ -54,9 +53,15 @@ class DtypeConfig: fp8_wgrad: bool = True fp8_dot_product_attention: bool = False fp8_multi_head_attention: bool = False + # FP16 Loss scaling + loss_scale: float = None, + initial_loss_scale: float = None, + min_loss_scale: float = None, + loss_scale_window: float = None, + hysteresis: float = None, -class MegatronMixedPrecision(MixedPrecision): +class MegatronMixedPrecision(Precision): def __init__( self, precision: Literal["16-mixed", "bf16-mixed", "32"], @@ -74,9 +79,11 @@ def __init__( fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, - native_amp_init_scale: int = 2**32, - native_amp_growth_interval: int = 1000, - native_amp_hysteresis: int = 2, + fp16_loss_scale: float = None, + fp16_initial_loss_scale: float = 4294967296, + fp16_min_loss_scale: float = 1.0, + fp16_loss_scale_window: int = 1000, + fp16_hysteresis: int = 2, device: str = "cuda", ) -> None: @@ -101,15 +108,14 
@@ def __init__( fp8_wgrad=fp8_wgrad, fp8_dot_product_attention=fp8_dot_product_attention, fp8_multi_head_attention=fp8_multi_head_attention, + # fp16 loss scale + loss_scale=fp16_loss_scale, + initial_loss_scale=fp16_initial_loss_scale, + min_loss_scale=fp16_min_loss_scale, + loss_scale_window=fp16_loss_scale_window, + hysteresis=fp16_hysteresis, ) - scaler = None - if self.dtype_config.fp16: - scaler = GradScaler( - init_scale=native_amp_init_scale, - growth_interval=native_amp_growth_interval, - hysteresis=native_amp_hysteresis, - ) - super().__init__(self.dtype_config, device, scaler) + super().__init__() def convert_module(self, module: Module) -> Module: """Convert the module parameters to the precision type this plugin handles. @@ -139,8 +145,8 @@ def convert_optimizer(self, optimizer: Optimizer) -> Optimizer: """ optim_config = get_optim_config(optimizer) - assert optim_config.bf16 == (self.precision == "bf16-mixed"), "BF16 enabled on model but not on optimizer" - assert optim_config.fp16 == (self.precision == "fp16-mixed"), "BF16 enabled on model but not on optimizer" + assert optim_config.bf16 == self.dtype_config.bf16, "BF16 enabled on model but not on optimizer" + assert optim_config.fp16 == self.dtype_config.fp16, "BF16 enabled on model but not on optimizer" return optimizer def convert_input(self, data: AnyT) -> AnyT: @@ -161,42 +167,6 @@ def convert_output(self, data: AnyT) -> AnyT: """ return data - def optimizer_step( - self, - optimizer: torch.optim.Optimizer, - model: Union["pl.LightningModule", torch.nn.Module], - closure: Callable[[], Any], - **kwargs: Any, - ) -> None: - from nemo.core.optim import MainParamsOptimizerWrapper - - if not isinstance(optimizer, MainParamsOptimizerWrapper): - return super().optimizer_step(optimizer, model, closure, **kwargs) - - if self.scaler is None: - assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation" - _ = closure() - self._after_closure(model, optimizer) - return optimizer.step(**kwargs) - - assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation" - closure_result = closure() - - # TODO: Add an option for merged all-reduce - - # cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update - optimizer.copy_model_grads_to_main_grads() - # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook. - # unscale main (fp32) gradients - self.scaler.unscale_(optimizer) - self._after_closure(model, optimizer) - skipped_backward = closure_result is None - # in manual optimization, the closure does not return a value - if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: - # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found - self.scaler.step(optimizer, **kwargs) - self.scaler.update() - @contextmanager def forward_context(self) -> Generator[None, None, None]: """No explicit precision casting. 
Inputs are supposed to be manually casted.""" From 428fd74ffb65fff6b1c19976a880aa15c745450d Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Aug 2024 13:52:02 -0700 Subject: [PATCH 10/16] remove unused param Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/plugins/mixed_precision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index bd70501c8e1a..c9f7c6f50361 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -84,7 +84,6 @@ def __init__( fp16_min_loss_scale: float = 1.0, fp16_loss_scale_window: int = 1000, fp16_hysteresis: int = 2, - device: str = "cuda", ) -> None: if isinstance(precision, int): From 122f9730472ba44c7f1977a7b923fc9906075935 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Aug 2024 14:38:55 -0700 Subject: [PATCH 11/16] Add precision plugin test Signed-off-by: Alexandros Koumparoulis --- tests/lightning/test_precision_plugin.py | 99 ++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/lightning/test_precision_plugin.py diff --git a/tests/lightning/test_precision_plugin.py b/tests/lightning/test_precision_plugin.py new file mode 100644 index 000000000000..133c0b9d638c --- /dev/null +++ b/tests/lightning/test_precision_plugin.py @@ -0,0 +1,99 @@ +from collections import defaultdict +from unittest.mock import MagicMock + +import pytest +from megatron.core import parallel_state +from torch import nn + +from nemo import lightning as nl +from nemo.lightning import megatron_parallel as mp +import pytorch_lightning as pl +from megatron.core.optimizer import OptimizerConfig +import torch +from nemo.collections import llm + + +class DummyTokenizer: + def __init__(self): + self.vocab_size = 30000 + +class TestMegatronMixedPrecision: + """Unit tests for the MegatronMixedPrecision class.""" + + @pytest.mark.run_only_on('GPU') + def test_precision_plugin_fp8_passed(self): + """Test __init__ with default parameters.""" + class TrainerHook(nl.Trainer): + def connect(self, model: pl.LightningModule) -> None: + assert model.config.bf16 == False + assert model.config.fp8 is None + super().connect(model) + assert model.config.fp8 == 'e4m3' + assert model.config.bf16 == True + + trainer = TrainerHook( + devices=2, + accelerator="gpu", + max_steps=2, + strategy=nl.MegatronStrategy( + tensor_model_parallel_size=2, + sequence_parallel=True, + ckpt_include_optimizer=False, + ), + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", fp8='e4m3'), + limit_val_batches=0.0, + num_sanity_val_steps=0, + ) + + optim = nl.MegatronOptimizerModule( + config=OptimizerConfig( + optimizer="adam", + lr=1e-5, + use_distributed_optimizer=False, + fp16=True, + params_dtype=torch.float32, + ), + ) + config = llm.Llama2Config7B() + config.num_layers = 2 + model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim) + trainer.strategy.connect(model) + + @pytest.mark.run_only_on('GPU') + def test_precision_plugin_precision_params_override(self): + """Test __init__ with default parameters.""" + trainer = nl.Trainer( + devices=2, + accelerator="gpu", + max_steps=2, + strategy=nl.MegatronStrategy( + tensor_model_parallel_size=2, + sequence_parallel=True, + ckpt_include_optimizer=False, + ), + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + limit_val_batches=0.0, + num_sanity_val_steps=0, + ) + + optim = nl.MegatronOptimizerModule( + 
config=OptimizerConfig( + optimizer="adam", + lr=1e-5, + use_distributed_optimizer=False, + fp16=True, + params_dtype=torch.float32, + ), + ) + config = llm.Llama2Config7B() + config.num_layers = 2 + config.fp16 = True + config.bf16 = False + model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim) + trainer.strategy.connect(model) + assert optim.config.bf16 is not None + assert optim.config.fp16 is not None + assert optim.config.bf16 == True + assert optim.config.fp16 == False + assert model.config.fp16 == False + assert model.config.bf16 == True From fa4bf44747ab0a30d0f5067b577a6aafecfb1b4c Mon Sep 17 00:00:00 2001 From: akoumpa Date: Fri, 16 Aug 2024 21:41:53 +0000 Subject: [PATCH 12/16] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/lightning/pytorch/plugins/mixed_precision.py | 10 +++++----- tests/lightning/test_precision_plugin.py | 10 ++++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index c9f7c6f50361..7f81a5beeb19 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -54,11 +54,11 @@ class DtypeConfig: fp8_dot_product_attention: bool = False fp8_multi_head_attention: bool = False # FP16 Loss scaling - loss_scale: float = None, - initial_loss_scale: float = None, - min_loss_scale: float = None, - loss_scale_window: float = None, - hysteresis: float = None, + loss_scale: float = (None,) + initial_loss_scale: float = (None,) + min_loss_scale: float = (None,) + loss_scale_window: float = (None,) + hysteresis: float = (None,) class MegatronMixedPrecision(Precision): diff --git a/tests/lightning/test_precision_plugin.py b/tests/lightning/test_precision_plugin.py index 133c0b9d638c..773dbc5b12bc 100644 --- a/tests/lightning/test_precision_plugin.py +++ b/tests/lightning/test_precision_plugin.py @@ -2,27 +2,29 @@ from unittest.mock import MagicMock import pytest +import pytorch_lightning as pl +import torch from megatron.core import parallel_state +from megatron.core.optimizer import OptimizerConfig from torch import nn from nemo import lightning as nl -from nemo.lightning import megatron_parallel as mp -import pytorch_lightning as pl -from megatron.core.optimizer import OptimizerConfig -import torch from nemo.collections import llm +from nemo.lightning import megatron_parallel as mp class DummyTokenizer: def __init__(self): self.vocab_size = 30000 + class TestMegatronMixedPrecision: """Unit tests for the MegatronMixedPrecision class.""" @pytest.mark.run_only_on('GPU') def test_precision_plugin_fp8_passed(self): """Test __init__ with default parameters.""" + class TrainerHook(nl.Trainer): def connect(self, model: pl.LightningModule) -> None: assert model.config.bf16 == False From 0b4308f4225ccf7e00bc35fd11cbe2a83546d19b Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Aug 2024 19:09:12 -0700 Subject: [PATCH 13/16] Also update __io__ configs Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/plugins/mixed_precision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 7f81a5beeb19..fd7ce81dc93a 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -176,6 +176,8 @@ def forward_context(self) -> Generator[None, None, None]: def 
update_config_with_dtype_overrides(dtype_config, config): + if hasattr(config, "__io__"): + config.__io__ = update_config_with_dtype_overrides(dtype_config, config.__io__) for field in fields(dtype_config): if not hasattr(config, field.name): continue From a52f6fb2783f7d2dd813f10b533340a9937f316a Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Aug 2024 19:10:46 -0700 Subject: [PATCH 14/16] remove unused imports Signed-off-by: Alexandros Koumparoulis --- tests/lightning/test_precision_plugin.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/lightning/test_precision_plugin.py b/tests/lightning/test_precision_plugin.py index 773dbc5b12bc..bdd834c3bf7a 100644 --- a/tests/lightning/test_precision_plugin.py +++ b/tests/lightning/test_precision_plugin.py @@ -1,16 +1,10 @@ -from collections import defaultdict -from unittest.mock import MagicMock - import pytest import pytorch_lightning as pl import torch -from megatron.core import parallel_state from megatron.core.optimizer import OptimizerConfig -from torch import nn from nemo import lightning as nl from nemo.collections import llm -from nemo.lightning import megatron_parallel as mp class DummyTokenizer: From 82997d28aa8140eefe34ef1bdb5cb9a221ba948d Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 19 Aug 2024 23:18:45 -0700 Subject: [PATCH 15/16] fix fabric to ptl converter mcore precision plugin Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/fabric/plugins.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index 79e1455cb33f..dba103abf2a4 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -124,6 +124,4 @@ def forward_context(self) -> Generator[None, None, None]: def _convert_megatron_mixed_precision(plugin: MegatronMixedPrecision) -> FabricMegatronMixedPrecision: return FabricMegatronMixedPrecision( precision=plugin.precision, - device=plugin.device, - scaler=plugin.scaler, ) From d5cf9f97ea68128d2854f6f966639e4d1ac7682e Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 20 Aug 2024 09:43:24 -0700 Subject: [PATCH 16/16] fix test Signed-off-by: Alexandros Koumparoulis --- nemo/lightning/pytorch/plugins/mixed_precision.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index fd7ce81dc93a..79394cc4bbb1 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -92,7 +92,7 @@ def __init__( dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 self.dtype_config = DtypeConfig( fp32=precision in ['fp32', '32'], - fp16=precision in ['fp16', 'fp16-mixed', '16'], + fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'], bf16=precision in ['bf16', 'bf16-mixed'], params_dtype=params_dtype or torch.float32, pipeline_dtype=pipeline_dtype or dtype, @@ -115,6 +115,12 @@ def __init__( hysteresis=fp16_hysteresis, ) super().__init__() + if self.dtype_config.fp16: + self.precision = "16-mixed" + elif self.dtype_config.bf16: + self.precision = "bf16-mixed" + else: + self.precision = "32-true" def convert_module(self, module: Module) -> Module: """Convert the module parameters to the precision type this plugin handles.
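
Taken together, the series replaces the GradScaler-based MegatronMixedPrecision plugin with a DtypeConfig-driven one and propagates its dtype/fp8 overrides onto the model, optimizer, and DDP configs through update_config_with_dtype_overrides. A minimal usage sketch, assuming the NeMo 2.0-style APIs exercised in the new tests/lightning/test_precision_plugin.py above (class and argument names are taken from these patches and may differ in released NeMo versions):

    import torch
    from megatron.core.optimizer import OptimizerConfig

    from nemo import lightning as nl
    from nemo.collections import llm


    class DummyTokenizer:
        # stand-in tokenizer, mirroring the one in tests/lightning/test_precision_plugin.py
        def __init__(self):
            self.vocab_size = 30000


    # bf16 activations with FP8 (e4m3) compute; the plugin builds a DtypeConfig and
    # MegatronStrategy.connect() copies its values onto model.config,
    # model.optim.config and the strategy's ddp_config.
    trainer = nl.Trainer(
        devices=2,
        accelerator="gpu",
        max_steps=2,
        strategy=nl.MegatronStrategy(
            tensor_model_parallel_size=2,
            sequence_parallel=True,
            ckpt_include_optimizer=False,
        ),
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", fp8="e4m3"),
        limit_val_batches=0.0,
        num_sanity_val_steps=0,
    )

    optim = nl.MegatronOptimizerModule(
        config=OptimizerConfig(
            optimizer="adam",
            lr=1e-5,
            use_distributed_optimizer=False,
            params_dtype=torch.float32,
        ),
    )
    config = llm.Llama2Config7B()
    config.num_layers = 2
    model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim)

    # After connect(), the dtype overrides have been applied:
    # model.config.bf16 is True and model.config.fp8 == "e4m3".
    trainer.strategy.connect(model)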