Skip to content

Commit

Permalink
Fix loading a universal checkpoint (#5263)
Browse files Browse the repository at this point in the history
This PR fixes the following two points regarding checkpoint loading.

- Load optimizer states
With [this PR](#5104), we
removed optimizer's `step()` on initialization. This made the DS's
parameter update match with PyTorch's normal behavior. However, we don't
have keys in optimizer states any more when we load a checkpoint.
For legacy/elastic checkpoints, the PR changed the checkpoint loaders to
create keys and buffers on loading. However, the loader for universal
checkpoints still relies on keys in optimizer states. As the result,
loading a universal checkpoint fails.
This PR fixes the loader to find optimizer state keys from a given
checkpoint.

- Resume step count
2943e6a
The checkpoint loader for a universal checkpoint resumes step count for
optimizer only when the param group already has `step`. But some
optimizers creates the key `step` in a param group at the first call of
`step()` (e.g. Apex [Fused
Adam](https://github.com/NVIDIA/apex/blob/810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c/apex/optimizers/fused_adam.py#L154).
In this case, the step count is not restored. This PR changes this
behavior to always set step count in a param group.
This PR also stop incrementing the step count when loading. I didn't see
why we need to increment the step count for my small example, but we may
need a discussion to consider various cases.
  • Loading branch information
tohtana authored Mar 13, 2024
1 parent 2df8e23 commit b112c99
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 21 deletions.
33 changes: 22 additions & 11 deletions deepspeed/checkpoint/universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,26 @@
# DeepSpeed Team

import os
import re
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
hp_mapping = self._hp_mapping
optim_state_keys = hp_mapping.get_optim_state_keys()
hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
#print(f'{hp_keys=}')
checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
for file in checkpoint_files.values():
assert os.path.isfile(file), f'{file} is not a valid file'
hp_mapping.optim_fragment = {}

hp_keys = []
for file in os.listdir(folder):
# We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
pattern = r'(.+).pt'
match = re.search(pattern, file)
if match:
hp_keys.append(match.group(1))

for key in hp_keys:
ckpt_file = checkpoint_files[key]
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)
full_hp_param = ckpt_dict[PARAM]

Expand Down Expand Up @@ -62,7 +66,6 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)

# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
Expand All @@ -84,13 +87,21 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

lp_frag_address = hp_mapping.lp_fragment_address
tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'

# print(f"{key} SHAPE: {tp_hp_slice.shape=}")
# print(f"{key} SHAPE: {dst_tensor.shape=}")
# print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
dst_tensor.data.copy_(tp_hp_fragment.data)

if key == FP32_WEIGHT_KEY:
dst_tensor = hp_mapping.get_hp_fragment()
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
dst_tensor.data.copy_(tp_hp_fragment.data)
else:
assert tp_hp_fragment.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'

hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()


def enable_universal_checkpoint(param_list):
Expand Down
10 changes: 8 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)

from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
Expand Down Expand Up @@ -457,12 +457,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
tp_world_size = self.mpu.get_slice_parallel_world_size()

for i, _ in enumerate(self.optimizer.param_groups):
for i, param_group in enumerate(self.optimizer.param_groups):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()

for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)

def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.immediate_grad_update
Expand Down
9 changes: 4 additions & 5 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2785,7 +2785,7 @@ def load_checkpoint(self,
if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'] + 1)
self.update_optimizer_step(step=client_states['iteration'])

return load_path, client_states

Expand Down Expand Up @@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def update_optimizer_step(self, step):

def set_step(d):
if isinstance(d['step'], torch.Tensor):
if 'step' in d and isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step
Expand All @@ -2975,10 +2975,9 @@ def set_step(d):
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
if 'step' in group:
set_step(group)
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0 and 'step' in state[p]:
if p in state and len(state[p]) > 0:
set_step(state[p])

def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
Expand Down
10 changes: 8 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint

from deepspeed.utils import groups
Expand Down Expand Up @@ -2310,12 +2310,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()

for i, _ in enumerate(self.optimizer.param_groups):
for i, param_group in enumerate(self.optimizer.param_groups):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()

for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)

def _load_global_state(self, sd):
self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .groups import *
from .nvtx import instrument_w_nvtx
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
Expand Down
15 changes: 15 additions & 0 deletions deepspeed/utils/tensor_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ def get_hp_fragment(self, optim_state_key=None):
return self.get_optim_state_fragment(optim_state_key)


def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
for key in opt_keys:
hp_param = flat_hp_tensor
buffer = torch.zeros_like(hp_param)

for lp in lp_tensors:
if lp._hp_mapping is not None:
hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
lp._hp_mapping.hp_fragment = hp_fragment

optim_state[hp_param][key] = buffer


def get_full_hp_param(self, optim_state_key=None):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
Expand Down

0 comments on commit b112c99

Please sign in to comment.