diff --git a/tests/pytorch/test_fp8_model_init.py b/tests/pytorch/test_fp8_model_init.py
index 4cce7ef816..1dfde214e3 100644
--- a/tests/pytorch/test_fp8_model_init.py
+++ b/tests/pytorch/test_fp8_model_init.py
@@ -29,12 +29,15 @@ def test_default(self) -> None:
             model = te.Linear(768, 768)
 
         assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
-        assert not hasattr(model.weight, "._high_precision_init_val"), \
-            "_high_precision_init_val should not exist"
-        assert not hasattr(model.weight, "get_high_precision_init_val"), \
-            "get_high_precision_init_val() should not exist"
-        assert not hasattr(model.weight, "clear_high_precision_init_val"), \
-            "clear_high_precision_init_val() should not exist"
+        assert not hasattr(
+            model.weight, "._high_precision_init_val"
+        ), "_high_precision_init_val should not exist"
+        assert not hasattr(
+            model.weight, "get_high_precision_init_val"
+        ), "get_high_precision_init_val() should not exist"
+        assert not hasattr(
+            model.weight, "clear_high_precision_init_val"
+        ), "clear_high_precision_init_val() should not exist"
 
     def test_preserve_high_precision_init_val(self) -> None:
         """Test fp8_model_init with preserve_high_precision_init_val=True"""
@@ -42,12 +45,15 @@ def test_preserve_high_precision_init_val(self) -> None:
             model = te.Linear(768, 768)
 
         assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
-        assert hasattr(model.weight, "_high_precision_init_val"), \
-            "_high_precision_init_val not found"
-        assert hasattr(model.weight, "get_high_precision_init_val"), \
-            "get_high_precision_init_val() not found"
-        assert hasattr(model.weight, "clear_high_precision_init_val"), \
-            "clear_high_precision_init_val() not found"
+        assert hasattr(
+            model.weight, "_high_precision_init_val"
+        ), "_high_precision_init_val not found"
+        assert hasattr(
+            model.weight, "get_high_precision_init_val"
+        ), "get_high_precision_init_val() not found"
+        assert hasattr(
+            model.weight, "clear_high_precision_init_val"
+        ), "clear_high_precision_init_val() not found"
 
         high_precision = model.weight.get_high_precision_init_val()
         assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
@@ -58,11 +64,14 @@ def test_preserve_high_precision_init_val(self) -> None:
             fp8_meta_index=model.weight._fp8_meta_index,
             amax=torch.empty(1, device="cuda"),  # Dummy amax to avoid overwriting history.
         )
-        assert torch.all(new_fp8._data == model.weight._data), \
-            "high_precision_init_val and model.weight are not equal"
+        assert torch.all(
+            new_fp8._data == model.weight._data
+        ), "high_precision_init_val and model.weight are not equal"
 
         model.weight.clear_high_precision_init_val()
-        assert model.weight.get_high_precision_init_val() is None, \
-            "clear_high_precision_init_val() not work"
-        assert not hasattr(model.weight, "._high_precision_init_val"), \
-            "clear_high_precision_init_val() not work"
+        assert (
+            model.weight.get_high_precision_init_val() is None
+        ), "clear_high_precision_init_val() not work"
+        assert not hasattr(
+            model.weight, "._high_precision_init_val"
+        ), "clear_high_precision_init_val() not work"
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index f54cda6429..bb799ef8e5 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -485,9 +485,9 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 @contextmanager
 def fp8_model_init(
-        enabled: bool = True,
-        preserve_high_precision_init_val: bool = False,
-    ) -> None:
+    enabled: bool = True,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index ceb4a286e4..3aebc1729b 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -883,6 +883,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
                 # a parameter so we always re-apply it just for extra safety.
                 param = torch.nn.Parameter(param)
                 if high_precision_init_val is not None:
+
                     def get(self):
                         if hasattr(self, "_high_precision_init_val"):
                             return self._high_precision_init_val
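For reference, a minimal usage sketch of the `fp8_model_init` API exercised by the tests in this diff (not part of the change itself; assumes `transformer_engine` is installed and an FP8-capable GPU is available):

```python
import transformer_engine.pytorch as te

# Create parameters directly in FP8 while keeping a high-precision copy of
# the initial values on the CPU (the behavior covered by the new test).
with te.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = te.Linear(768, 768)

# The preserved master copy lives on the CPU until explicitly released.
master = model.weight.get_high_precision_init_val()
print(master.device)  # cpu

# Free the CPU copy once it is no longer needed.
model.weight.clear_high_precision_init_val()
assert model.weight.get_high_precision_init_val() is None
```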