Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
  • Loading branch information
akoumpa committed Aug 20, 2024
1 parent 82997d2 commit d5cf9f9
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion nemo/lightning/pytorch/plugins/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(
dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
self.dtype_config = DtypeConfig(
fp32=precision in ['fp32', '32'],
fp16=precision in ['fp16', 'fp16-mixed', '16'],
fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
bf16=precision in ['bf16', 'bf16-mixed'],
params_dtype=params_dtype or torch.float32,
pipeline_dtype=pipeline_dtype or dtype,
Expand All @@ -115,6 +115,12 @@ def __init__(
hysteresis=fp16_hysteresis,
)
super().__init__()
if self.dtype_config.fp16:
self.precision = "16-mixed"
elif self.dtype_config.bf16:
self.precision = "bf16-mixed"
else:
self.precision = "32-true"

def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
Expand Down

0 comments on commit d5cf9f9

Please sign in to comment.