Skip to content

Commit

Permalink
Add model config option patching
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
  • Loading branch information
akoumpa committed Aug 13, 2024
1 parent a123bc1 commit 9140e4b
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions nemo/lightning/pytorch/plugins/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ def convert_module(self, module: Module) -> Module:
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_model_config

if self.precision in ["16-mixed", "bf16-mixed"]:
if self.dtype_config.fp16 or self.dtype_config.bf16:
# Patch config options
# @akoumparouli: is this too late?
config = get_model_config(module.module)
config.fp16 = self.precision == "16-mixed"
config.bf16 = self.precision == "bf16-mixed"
config.fp16 = self.dtype_config.fp16
config.bf16 = self.dtype_config.bf16
if isinstance(module.module, Float16Module):
new_float16_module = Float16Module(config, module.module.module)
module.module = new_float16_module
Expand Down

0 comments on commit 9140e4b

Please sign in to comment.