
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 28, 2024
1 parent 8866920 commit ccb6b63
Showing 3 changed files with 31 additions and 21 deletions.
45 changes: 27 additions & 18 deletions tests/pytorch/test_fp8_model_init.py
@@ -29,25 +29,31 @@ def test_default(self) -> None:
             model = te.Linear(768, 768)
 
         assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
-        assert not hasattr(model.weight, "._high_precision_init_val"), \
-            "_high_precision_init_val should not exist"
-        assert not hasattr(model.weight, "get_high_precision_init_val"), \
-            "get_high_precision_init_val() should not exist"
-        assert not hasattr(model.weight, "clear_high_precision_init_val"), \
-            "clear_high_precision_init_val() should not exist"
+        assert not hasattr(
+            model.weight, "._high_precision_init_val"
+        ), "_high_precision_init_val should not exist"
+        assert not hasattr(
+            model.weight, "get_high_precision_init_val"
+        ), "get_high_precision_init_val() should not exist"
+        assert not hasattr(
+            model.weight, "clear_high_precision_init_val"
+        ), "clear_high_precision_init_val() should not exist"
 
     def test_preserve_high_precision_init_val(self) -> None:
         """Test fp8_model_init with preserve_high_precision_init_val=True"""
         with fp8_model_init(preserve_high_precision_init_val=True):
             model = te.Linear(768, 768)
 
         assert isinstance(model.weight, Float8Tensor), "Weight should be Float8Tensor"
-        assert hasattr(model.weight, "_high_precision_init_val"), \
-            "_high_precision_init_val not found"
-        assert hasattr(model.weight, "get_high_precision_init_val"), \
-            "get_high_precision_init_val() not found"
-        assert hasattr(model.weight, "clear_high_precision_init_val"), \
-            "clear_high_precision_init_val() not found"
+        assert hasattr(
+            model.weight, "_high_precision_init_val"
+        ), "_high_precision_init_val not found"
+        assert hasattr(
+            model.weight, "get_high_precision_init_val"
+        ), "get_high_precision_init_val() not found"
+        assert hasattr(
+            model.weight, "clear_high_precision_init_val"
+        ), "clear_high_precision_init_val() not found"
 
         high_precision = model.weight.get_high_precision_init_val()
         assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
@@ -58,11 +64,14 @@ def test_preserve_high_precision_init_val(self) -> None:
             fp8_meta_index=model.weight._fp8_meta_index,
             amax=torch.empty(1, device="cuda"),  # Dummy amax to avoid overwriting history.
         )
-        assert torch.all(new_fp8._data == model.weight._data), \
-            "high_precision_init_val and model.weight are not equal"
+        assert torch.all(
+            new_fp8._data == model.weight._data
+        ), "high_precision_init_val and model.weight are not equal"
 
         model.weight.clear_high_precision_init_val()
-        assert model.weight.get_high_precision_init_val() is None, \
-            "clear_high_precision_init_val() not work"
-        assert not hasattr(model.weight, "._high_precision_init_val"), \
-            "clear_high_precision_init_val() not work"
+        assert (
+            model.weight.get_high_precision_init_val() is None
+        ), "clear_high_precision_init_val() not work"
+        assert not hasattr(
+            model.weight, "._high_precision_init_val"
+        ), "clear_high_precision_init_val() not work"
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/fp8.py
@@ -485,9 +485,9 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 @contextmanager
 def fp8_model_init(
-        enabled: bool = True,
-        preserve_high_precision_init_val: bool = False,
-) -> None:
+    enabled: bool = True,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
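For orientation, the snippet below sketches how the new preserve_high_precision_init_val flag is exercised; it mirrors the test changes above and is not part of this commit. It assumes the imports used by the test file (transformer_engine.pytorch as te and fp8_model_init).

import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init

# Allocate parameters directly in FP8 and additionally keep a CPU copy of the
# original high-precision initialization values.
with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = te.Linear(768, 768)

# Retrieve the preserved CPU copy (e.g. to seed optimizer master weights) ...
master_init = model.weight.get_high_precision_init_val()

# ... and drop it once it is no longer needed; afterwards the getter returns None.
model.weight.clear_high_precision_init_val()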
1 change: 1 addition & 0 deletions transformer_engine/pytorch/module/base.py
@@ -883,6 +883,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # a parameter so we always re-apply it just for extra safety.
             param = torch.nn.Parameter(param)
             if high_precision_init_val is not None:
+
                 def get(self):
                     if hasattr(self, "_high_precision_init_val"):
                         return self._high_precision_init_val
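The remainder of the added accessor is collapsed in this view. As a rough, hypothetical sketch (not the commit's verbatim code) of how such helpers could be bound onto the parameter, assuming param and high_precision_init_val come from the surrounding reset_parameters code:

from types import MethodType

def get(self):
    # Return the preserved high-precision copy if it is still attached.
    if hasattr(self, "_high_precision_init_val"):
        return self._high_precision_init_val
    return None

def clear(self):
    # Drop the preserved copy so the CPU memory can be reclaimed.
    if hasattr(self, "_high_precision_init_val"):
        del self._high_precision_init_val

# Hypothetical binding; the actual lines are collapsed in the diff above.
param._high_precision_init_val = high_precision_init_val
param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param)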
