Akoumparouli/nemo ux precision plugin refactor (NVIDIA#10129)
* rename mixed_precision.py to precision.py
* replace print with logging.warning
* Apply isort and black reformatting
* also patch ddp_config
* Rename patch_dtype_config to update_config_with_dtype_overrides
* Add GradScaler's args to constructor's arg list
* Apply isort and black reformatting
* fix import
* Leverage mcore's fp16 grad scaler
* remove unused param
* Add precision plugin test
* Apply isort and black reformatting
* Also update __io__ configs
* remove unused imports
* fix fabric to ptl converter mcore precision plugin
* fix test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
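For orientation, the sketch below illustrates the kind of dtype-override pass the renamed helper performs: copying bf16/fp16/params_dtype-style fields from the plugin's dtype config onto a target config (model config, optimizer config, or ddp_config) and emitting logging.warning instead of print when a value changes. This is a minimal hypothetical sketch, not NeMo's actual update_config_with_dtype_overrides; the real helper's signature, field selection, and the mcore grad-scaler wiring may differ.

import dataclasses
import logging


def update_config_with_dtype_overrides(dtype_config, target_config):
    """Hypothetical stand-in for NeMo's helper of the same name.

    Assumes both configs are dataclass instances (as OptimizerConfig is).
    Copies dtype-related fields that exist on both configs from dtype_config
    onto target_config, warning on each value that gets overridden.
    """
    for field in dataclasses.fields(dtype_config):
        if not hasattr(target_config, field.name):
            continue
        old_value = getattr(target_config, field.name)
        new_value = getattr(dtype_config, field.name)
        if old_value != new_value:
            # logging.warning rather than print, per the commit above
            logging.warning("Overriding %s: %r -> %r", field.name, old_value, new_value)
            setattr(target_config, field.name, new_value)
    return target_config

The new test file below then exercises this override behaviour through trainer.strategy.connect(model), checking both the model config and the optimizer config.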
Showing 4 changed files with 212 additions and 54 deletions.
@@ -0,0 +1,95 @@
import pytest
import pytorch_lightning as pl
import torch
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm


class DummyTokenizer:
    def __init__(self):
        self.vocab_size = 30000


class TestMegatronMixedPrecision:
    """Unit tests for the MegatronMixedPrecision class."""

    @pytest.mark.run_only_on('GPU')
    def test_precision_plugin_fp8_passed(self):
        """Test that fp8 passed to the plugin is applied to the model config on connect."""

        # Hook connect() to check the dtype/fp8 flags before and after the
        # precision plugin patches the model config.
        class TrainerHook(nl.Trainer):
            def connect(self, model: pl.LightningModule) -> None:
                assert model.config.bf16 == False
                assert model.config.fp8 is None
                super().connect(model)
                assert model.config.fp8 == 'e4m3'
                assert model.config.bf16 == True

        trainer = TrainerHook(
            devices=2,
            accelerator="gpu",
            max_steps=2,
            strategy=nl.MegatronStrategy(
                tensor_model_parallel_size=2,
                sequence_parallel=True,
                ckpt_include_optimizer=False,
            ),
            plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", fp8='e4m3'),
            limit_val_batches=0.0,
            num_sanity_val_steps=0,
        )

        optim = nl.MegatronOptimizerModule(
            config=OptimizerConfig(
                optimizer="adam",
                lr=1e-5,
                use_distributed_optimizer=False,
                fp16=True,
                params_dtype=torch.float32,
            ),
        )
        config = llm.Llama2Config7B()
        config.num_layers = 2
        model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim)
        trainer.strategy.connect(model)

    @pytest.mark.run_only_on('GPU')
    def test_precision_plugin_precision_params_override(self):
        """Test that the plugin's bf16-mixed precision overrides the fp16/bf16 flags on the model and optimizer configs."""
        trainer = nl.Trainer(
            devices=2,
            accelerator="gpu",
            max_steps=2,
            strategy=nl.MegatronStrategy(
                tensor_model_parallel_size=2,
                sequence_parallel=True,
                ckpt_include_optimizer=False,
            ),
            plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
            limit_val_batches=0.0,
            num_sanity_val_steps=0,
        )

        optim = nl.MegatronOptimizerModule(
            config=OptimizerConfig(
                optimizer="adam",
                lr=1e-5,
                use_distributed_optimizer=False,
                fp16=True,
                params_dtype=torch.float32,
            ),
        )
        config = llm.Llama2Config7B()
        config.num_layers = 2
        config.fp16 = True
        config.bf16 = False
        model = llm.LlamaModel(config, tokenizer=DummyTokenizer(), optim=optim)
        # connect() applies the precision plugin's dtype overrides to both
        # the model config and the optimizer config.
        trainer.strategy.connect(model)
        assert optim.config.bf16 is not None
        assert optim.config.fp16 is not None
        assert optim.config.bf16 == True
        assert optim.config.fp16 == False
        assert model.config.fp16 == False
        assert model.config.bf16 == True