From d5cf9f97ea68128d2854f6f966639e4d1ac7682e Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 20 Aug 2024 09:43:24 -0700
Subject: [PATCH] fix test

Signed-off-by: Alexandros Koumparoulis
---
 nemo/lightning/pytorch/plugins/mixed_precision.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py
index fd7ce81dc93a..79394cc4bbb1 100644
--- a/nemo/lightning/pytorch/plugins/mixed_precision.py
+++ b/nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -92,7 +92,7 @@ def __init__(
         dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
         self.dtype_config = DtypeConfig(
             fp32=precision in ['fp32', '32'],
-            fp16=precision in ['fp16', 'fp16-mixed', '16'],
+            fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
             bf16=precision in ['bf16', 'bf16-mixed'],
             params_dtype=params_dtype or torch.float32,
             pipeline_dtype=pipeline_dtype or dtype,
@@ -115,6 +115,12 @@ def __init__(
             hysteresis=fp16_hysteresis,
         )
         super().__init__()
+        if self.dtype_config.fp16:
+            self.precision = "16-mixed"
+        elif self.dtype_config.bf16:
+            self.precision = "bf16-mixed"
+        else:
+            self.precision = "32-true"
 
     def convert_module(self, module: Module) -> Module:
         """Convert the module parameters to the precision type this plugin handles.
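
Note: a minimal sketch of what this patch enables, assuming the plugin class
defined in mixed_precision.py is MegatronMixedPrecision and that it can be
constructed with only the precision argument (both are assumptions; neither
the class name nor the full __init__ signature appears in the hunks above):

    from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision

    # Before this patch, the "16-mixed" alias was not matched by the fp16
    # membership test, so dtype_config.fp16 was False for that value.
    plugin = MegatronMixedPrecision(precision="16-mixed")
    assert plugin.dtype_config.fp16

    # The new branch after super().__init__() normalizes self.precision to
    # one of the canonical mixed-precision strings ("16-mixed", "bf16-mixed",
    # or "32-true"), regardless of which alias was passed in.
    assert plugin.precision == "16-mixed"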