From b112c99ea8e09eb06ada0d60a3687983cb8c4bd0 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Wed, 13 Mar 2024 15:33:44 -0700
Subject: [PATCH] Fix loading a universal checkpoint (#5263)

This PR fixes the following two points regarding checkpoint loading.

- Load optimizer states

With [this PR](https://github.com/microsoft/DeepSpeed/pull/5104), we removed the optimizer's `step()` call on initialization. This made DeepSpeed's parameter updates match PyTorch's normal behavior. However, it also means that the optimizer states no longer contain their keys when a checkpoint is loaded. For legacy/elastic checkpoints, that PR changed the checkpoint loaders to create the keys and buffers at load time, but the loader for universal checkpoints still relies on the keys being present in the optimizer states. As a result, loading a universal checkpoint fails. This PR fixes the loader to discover the optimizer state keys from the given checkpoint instead.

- Resume step count

https://github.com/microsoft/DeepSpeed/pull/5263/commits/2943e6ab7e156946a018ab2a08c7f3ba45b55e01

The universal checkpoint loader resumed the optimizer step count only when the param group already had a `step` entry. However, some optimizers create the `step` key in a param group only on the first call to `step()` (e.g. Apex [FusedAdam](https://github.com/NVIDIA/apex/blob/810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c/apex/optimizers/fused_adam.py#L154)). In that case, the step count was not restored. This PR changes the behavior to always set the step count in each param group.

This PR also stops incrementing the step count when loading. In my small example I did not see why the step count needs to be incremented, but we may need a discussion to cover various cases.
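The "always set" behavior can be summarized with the following minimal sketch. This is not the exact DeepSpeed code: `restore_step` is a hypothetical helper, `optimizer` stands for the wrapped base optimizer, and `step` for the iteration count read from the checkpoint.

```python
import torch


def restore_step(optimizer, step):
    """Write `step` into every param group and per-param state, even if the
    optimizer has not created a 'step' entry yet (e.g. Apex FusedAdam)."""

    def set_step(d):
        # Keep the tensor dtype/device when the optimizer stores step as a tensor.
        if 'step' in d and isinstance(d['step'], torch.Tensor):
            d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
        else:
            d['step'] = step

    for group in optimizer.param_groups:
        set_step(group)  # always set, even when the key does not exist yet
        for p in group['params']:
            state = optimizer.state
            if p in state and len(state[p]) > 0:
                set_step(state[p])
```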
---
 deepspeed/checkpoint/universal_checkpoint.py | 33 +++++++++++++-------
 deepspeed/runtime/bf16_optimizer.py          | 10 ++++--
 deepspeed/runtime/engine.py                  |  9 +++---
 deepspeed/runtime/zero/stage_1_and_2.py      | 10 ++++--
 deepspeed/utils/__init__.py                  |  2 +-
 deepspeed/utils/tensor_fragment.py           | 15 +++++++++
 6 files changed, 58 insertions(+), 21 deletions(-)

diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index 542d1125c566..a1314e004969 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -4,6 +4,7 @@
 # DeepSpeed Team
 
 import os
+import re
 import torch
 import types
 from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)
@@ -11,15 +12,18 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     hp_mapping = self._hp_mapping
-    optim_state_keys = hp_mapping.get_optim_state_keys()
-    hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
-    #print(f'{hp_keys=}')
-    checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
-    for file in checkpoint_files.values():
-        assert os.path.isfile(file), f'{file} is not a valid file'
+    hp_mapping.optim_fragment = {}
+
+    hp_keys = []
+    for file in os.listdir(folder):
+        # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
+        pattern = r'(.+).pt'
+        match = re.search(pattern, file)
+        if match:
+            hp_keys.append(match.group(1))
 
     for key in hp_keys:
-        ckpt_file = checkpoint_files[key]
+        ckpt_file = os.path.join(folder, f"{key}.pt")
         ckpt_dict = torch.load(ckpt_file)
         full_hp_param = ckpt_dict[PARAM]
@@ -62,7 +66,6 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
         assert full_param_numel == tp_world_size * tp_slice_numel, \
             f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
 
-        dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)
         # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
         # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@@ -84,13 +87,21 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
         lp_frag_address = hp_mapping.lp_fragment_address
         tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
-        assert dst_tensor.numel() == lp_frag_address.numel, \
-            f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
 
         # print(f"{key} SHAPE: {tp_hp_slice.shape=}")
         # print(f"{key} SHAPE: {dst_tensor.shape=}")
         # print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
-        dst_tensor.data.copy_(tp_hp_fragment.data)
+
+        if key == FP32_WEIGHT_KEY:
+            dst_tensor = hp_mapping.get_hp_fragment()
+            assert dst_tensor.numel() == lp_frag_address.numel, \
+                f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
+            dst_tensor.data.copy_(tp_hp_fragment.data)
+        else:
+            assert tp_hp_fragment.numel() == lp_frag_address.numel, \
+                f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'
+
+            hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()
 
 
 def enable_universal_checkpoint(param_list):
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index aaa836bf1c31..82c8dda423a6 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -18,7 +18,7 @@
                                      align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
                                      is_model_parallel_parameter, see_memory_usage, graph_process)
 
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -457,12 +457,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         tp_world_size = self.mpu.get_slice_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bf16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)
 
     def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
         assert self.immediate_grad_update
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 5c1202ba06ae..174e699c5202 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
 
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
             if load_zero_checkpoint:
-                self.update_optimizer_step(step=client_states['iteration'] + 1)
+                self.update_optimizer_step(step=client_states['iteration'])
 
         return load_path, client_states
 
@@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
     def update_optimizer_step(self, step):
 
         def set_step(d):
-            if isinstance(d['step'], torch.Tensor):
+            if 'step' in d and isinstance(d['step'], torch.Tensor):
                 d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
             else:
                 d['step'] = step
@@ -2975,10 +2975,9 @@ def set_step(d):
         base_optimizer = optimizer.optimizer
         state = base_optimizer.state
         for group in optimizer.param_groups:
-            if 'step' in group:
-                set_step(group)
+            set_step(group)
             for p in group['params']:
-                if p in state and len(state[p]) > 0 and 'step' in state[p]:
+                if p in state and len(state[p]) > 0:
                     set_step(state[p])
 
     def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index e8823f153fb8..6cfcc418e71a 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -28,7 +28,7 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
                                             BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.utils import groups
@@ -2310,12 +2310,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
             else self.mpu.get_tensor_model_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bit16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)
 
     def _load_global_state(self, sd):
         self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py
index 33ea8ba60818..75fb6aa9d30a 100644
--- a/deepspeed/utils/__init__.py
+++ b/deepspeed/utils/__init__.py
@@ -10,7 +10,7 @@
 from .groups import *
 from .nvtx import instrument_w_nvtx
 # TODO: Move tensor fragment and mixed precision to zero utils
-from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
+from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
 from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
 from .tensor_fragment import set_full_hp_param
 from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py
index 49eefafcfbcc..b34722580ddd 100644
--- a/deepspeed/utils/tensor_fragment.py
+++ b/deepspeed/utils/tensor_fragment.py
@@ -58,6 +58,21 @@ def get_hp_fragment(self, optim_state_key=None):
         return self.get_optim_state_fragment(optim_state_key)
 
 
+def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
+    for key in opt_keys:
+        hp_param = flat_hp_tensor
+        buffer = torch.zeros_like(hp_param)
+
+        for lp in lp_tensors:
+            if lp._hp_mapping is not None:
+                hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
+                hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
+                hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
+                lp._hp_mapping.hp_fragment = hp_fragment
+
+        optim_state[hp_param][key] = buffer
+
+
 def get_full_hp_param(self, optim_state_key=None):
     reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
     if self._hp_mapping is not None: