[PyTorch] Debug checkpointing with operation-based API #1063

Open
wants to merge 5 commits into main

184 changes: 184 additions & 0 deletions tests/pytorch/test_torch_save_load.py
@@ -17,7 +17,9 @@

import pytest
import torch
import transformer_engine.common
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
@@ -287,3 +289,185 @@ def test_fp8_model_checkpoint(
torch.testing.assert_close(
model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()
)


@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("save_fp8_model", (False, True))
@pytest.mark.parametrize("load_fp8_model", (False, True))
def test_sequential_model(
*,
in_shape: Iterable[int] = (16, 16),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
save_steps: int = 2,
load_steps: int = 2,
fp8: bool,
save_fp8_model: bool,
load_fp8_model: bool,
) -> None:

# Skip invalid configurations
if fp8 or save_fp8_model or load_fp8_model:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# FP8 recipe
margin = 2
fp8_format = transformer_engine.common.recipe.Format.E4M3
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
fp8_format=fp8_format,
amax_history_len=8,
amax_compute_algo="max",
)

# Construct model to save to checkpoint
with te.fp8_model_init(enabled=save_fp8_model):
model = te_ops.Sequential(
te_ops.BasicLinear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
)
with torch.no_grad():
torch.rand(model[0].weight.size(), out=model[0].weight)

# Synthetic data
xs_ref = [
torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
]
dys_ref = [
torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
]

def train_step(
model: te_ops.Sequential,
x: torch.Tensor,
dy: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Helper function to perform training step"""
x = x.detach().clone().requires_grad_()
dy = dy.detach().clone()
with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe):
y = model(x)
y.backward(dy)
with torch.no_grad():
for param in model.parameters():
param += 0.125
return (
y.detach().clone(),
x.grad.detach().clone(),
model[0].weight.detach().float().clone(),
)

# Initial training steps with saved model
ys_ref = []
dxs_ref = []
ws_ref = []
for step in range(save_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
ys_ref.append(y)
dxs_ref.append(dx)
ws_ref.append(w)

# Keep track of FP8 metadata if needed
fp8_meta_ref = dict(input={}, param={}, grad_output={})
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
m_ref["amax"] = m_model.amax_history.detach().clone()
m_ref["scale"] = m_model.scale.detach().clone()
m_ref["scale_inv"] = m_model.scale_inv.detach().clone()
del m_model, m_ref

# Save checkpoint
byte_stream = io.BytesIO()
torch.save(model.state_dict(), byte_stream)
model_bytes = byte_stream.getvalue()
del byte_stream

# More training steps with saved model
for step in range(save_steps, save_steps + load_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
ys_ref.append(y)
dxs_ref.append(dx)
ws_ref.append(w)

# Disturb and destroy model
with torch.no_grad():
for param in model.parameters():
param.zero_()
model[0]._fp8_metas = None
del model

# Construct new model to load from checkpoint
with te.fp8_model_init(enabled=load_fp8_model):
model = te_ops.Sequential(
te_ops.BasicLinear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
)

# Tolerances for numerical checks
tols = {}
if fp8 or save_fp8_model or load_fp8_model:
tols = dict(rtol=0.125, atol=0.0675)  # fp8e4m3 epsilon = 0.0625
exact_tols = dict(rtol=0, atol=0)

# Training steps with dummy data
for step in range(save_steps):
y, dx, w = train_step(
model,
torch.zeros_like(xs_ref[step]),
torch.zeros_like(dys_ref[step]),
)

# Make sure results don't match saved model
with pytest.raises(AssertionError):
torch.testing.assert_close(y, ys_ref[step], **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(dx, dxs_ref[step], **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(w, ws_ref[step], **tols)

# Make sure new model's FP8 metadata doesn't match saved model
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)

# Load checkpoint
model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
del model_bytes

# Check that new model's FP8 metadata matches saved model
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)

# More training steps with loaded model
for step in range(save_steps, save_steps + load_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
torch.testing.assert_close(y, ys_ref[step], **tols)
torch.testing.assert_close(dx, dxs_ref[step], **tols)
torch.testing.assert_close(w, ws_ref[step], **tols)
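
For readers skimming the diff, here is a minimal sketch (not part of this PR) of the save/load round trip the new test exercises. It assumes an FP8-capable CUDA GPU and uses only the te_ops API shown above; the perturbation and mismatch checks from the test are omitted.

import io
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Build a small operation-based model and run one FP8 step so that
# amax/scale metadata gets populated (default DelayedScaling recipe).
model = te_ops.Sequential(te_ops.BasicLinear(16, 16, device="cuda"))
x = torch.rand(16, 16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True):
    y = model(x)
y.backward(torch.rand_like(y))

# Save: FP8 scales and amax histories travel with the state dict as "extra state".
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Load into a freshly constructed model; the override below applies the extra
# state before the parameters so FP8 casts see the right scaling factors.
restored = te_ops.Sequential(te_ops.BasicLinear(16, 16, device="cuda"))
restored.load_state_dict(torch.load(io.BytesIO(buffer.getvalue())))
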
156 changes: 156 additions & 0 deletions transformer_engine/pytorch/ops/op.py
@@ -8,6 +8,7 @@
import abc
from collections.abc import Iterable
import dataclasses
import pickle
from typing import Any, Optional

import torch
@@ -375,6 +376,161 @@ def forward(

return OperationFuser([self], fuse_ops=False)(input, [kwargs])

def get_extra_state(self) -> Optional[torch.Tensor]:
"""Serialize extra state

Contains metadata for FP8 casting.

"""

# This implementation is working around a few issues:
#
# (1) PyTorch's "extra state" infrastructure might be able to
# support any picklable type, but they make no guarantees.
# It seems that ONNX export experiences issues with
# non-tensor extra state.
# (2) PyTorch's checkpointing infrastructure does not remap
# devices for "extra state" like it does for "state dict".
# Thus, we want to avoid putting extra state on the GPU
# since it may be loaded on the wrong device.
# (3) The extra state consists of many small tensors. If we
# want to copy them all to CPU, then we need to avoid the
# overhead of many GPU-CPU memory transfers.
#
# See: https://github.com/NVIDIA/TransformerEngine/pull/351
# See: https://github.com/NVIDIA/TransformerEngine/pull/363

# Return immediately if op has no FP8 state
has_fp8_state = any(
self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
)
if not has_fp8_state:
return None

def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor

Memory transfer is asynchronous w.r.t. host, so GPU should
be synchronized before using result.

"""
dst = torch.empty_like(src, device="cpu")
dst.copy_(src, non_blocking=True)
return dst

# Store FP8 state
state = {}
for mode in ("input", "param", "grad_output"):

# Get state for a given FP8 tensor
if self.num_fp8_scales(mode) == 0:
state[mode] = None
continue
fp8_meta = self.get_fp8_meta(mode)
if fp8_meta is None:
continue
state[mode] = {}

# Store tensors
if "scaling_fwd" in fp8_meta:
state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv)
state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
if "scaling_bwd" in fp8_meta:
state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv)
state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

# Store other picklable items
extra = {}
for key, val in fp8_meta.items():
if key == "buffer_index_and_autocast_key":
continue
if not isinstance(val, (bool, int, float, str, tuple, list)):
continue
extra[key] = val
state[mode]["extra_fp8_variables"] = extra

# Serialize state into byte tensor
torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(state))
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized

def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
"""Load extra state"""
if state is None:
return

# Deserialize state from byte tensor
state = pickle.loads(state.detach().numpy(force=True).tobytes())
if state is None:
return

def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
"""Helper function to copy tensor from CPU

Memory transfer is asynchronous w.r.t. host, so GPU should
be synchronized before using result.

"""
if src.size() != dst.size():
dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
dst.copy_(src, non_blocking=True)

# Load FP8 state
for mode in ("input", "param", "grad_output"):

# Get state for a given FP8 tensor
if mode not in state:
continue
if self.num_fp8_scales(mode) == 0:
continue
fp8_meta = self.get_fp8_meta(mode)
if fp8_meta is None:
continue

# Load extra state
fp8_meta.update(state[mode]["extra_fp8_variables"])
if "amax_history_fwd" in state[mode]:
fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
elif "amax_history_bwd" in state[mode]:
fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

# Load tensors
fp8_meta = self.get_fp8_meta(mode)
if "scaling_fwd" in fp8_meta:
fp8_meta_fwd = fp8_meta["scaling_fwd"]
copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv)
copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
if "scaling_bwd" in fp8_meta:
fp8_meta_bwd = fp8_meta["scaling_bwd"]
copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv)
copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)

# Finish CPU-GPU memory transfers
torch.cuda.synchronize()

def _load_from_state_dict(self, *args, **kwargs) -> None:
"""Load state"""

# In the base PyTorch module class, the extra state is loaded
# _after_ the parameters. However, copying values into FP8
# parameters requires an FP8 cast, which uses a scaling factor
# from the operation's FP8 metadata. The FP8 metadata is
# included in the operation's extra state, so we need to
# manually load the extra state before loading parameters.

state_dict, prefix = args[0], args[1]
extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
super()._load_from_state_dict(*args, **kwargs)


class FusedOperation(FusibleOperation):
"""Compound tensor operation supported by the operation fuser
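
The get_extra_state/set_extra_state pair above relies on two standard PyTorch behaviors: a module's extra state is stored in the state dict under the key prefix + "_extra_state", and load_state_dict hands that entry back to set_extra_state. Below is a self-contained toy module (illustrative only, not part of the PR) using the same pickle-to-uint8-tensor trick, so that the extra state stays a plain CPU tensor that is friendly to ONNX export and device remapping.

import pickle
import torch

class ExtraStateDemo(torch.nn.Module):
    """Toy module that checkpoints non-parameter state as a byte tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(4))
        self.counter = 0  # non-parameter state we want to checkpoint

    def get_extra_state(self) -> torch.Tensor:
        # Pickle into a CPU uint8 tensor so the state-dict entry is a plain tensor.
        payload = bytearray(pickle.dumps({"counter": self.counter}))
        return torch.frombuffer(payload, dtype=torch.uint8)

    def set_extra_state(self, state: torch.Tensor) -> None:
        self.counter = pickle.loads(state.numpy().tobytes())["counter"]

module = ExtraStateDemo()
module.counter = 3
state_dict = module.state_dict()
assert "_extra_state" in state_dict  # torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX

restored = ExtraStateDemo()
restored.load_state_dict(state_dict)
assert restored.counter == 3
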