From 7e40c76dbb9f6a193177cb1edf099366d37a46bc Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 21 Sep 2023 13:52:52 +0800 Subject: [PATCH 1/5] support unsharded saving/loading for model --- .../hybrid_parallel_checkpoint_io.py | 92 ++++++++++++++++--- 1 file changed, 77 insertions(+), 15 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 41e53b3b388f..5aeabd29cb7e 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -9,7 +9,6 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup -from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import DistCoordinator @@ -24,10 +23,12 @@ get_optimizer_base_filenames, is_safetensors_available, load_shard_state_dict, + load_state_dict, load_state_dict_into_model, load_states_into_optimizer, save_config_file, save_param_groups, + save_state_dict, save_state_dict_shards, search_tp_partition_dim, sharded_optimizer_loading_epilogue, @@ -217,7 +218,7 @@ def save_sharded_model( index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model, checkpoint) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -273,7 +274,7 @@ def save_sharded_model( final_index_file.write_index_file(final_index_file_path) save_config_file(model, checkpoint) rmtree(tmp_index_file_folder) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -353,7 +354,7 @@ def _load(name: str): # Update master params if mixed-precision training is enabled. model_before_wrapping.update_master_params() - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") def save_sharded_optimizer( @@ -424,7 +425,7 @@ def save_sharded_optimizer( # Store index file. index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -484,7 +485,7 @@ def save_sharded_optimizer( final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -579,23 +580,84 @@ def _get_param_id_from_optimizer_param( optimizer.optim.state[param] = sharded_state sharded_optimizer_loading_epilogue(optimizer.optim) - if self.verbose: + if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model state dict to a single file with given checkpointing path. - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. + gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + + if self.dp_rank != 0: + return + + # The logic of collecting parameter shards along tp degree + # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. + state_dict = model.state_dict() + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + if self.tp_rank == 0: + save_state_dict(state_dict, checkpoint, use_safetensors) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + state_dict_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + dist.all_gather_object(state_dict_list, state_dict, self.pp_group) + + # Only the master rank do the saving. + if self.coordinator.is_master(): + complete_state_dict = dict() + for _state_dict in state_dict_list: + complete_state_dict.update(_state_dict) + save_state_dict(complete_state_dict, checkpoint, use_safetensors) + + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False): + """ + Load model from a single file with the given path of checkpoint. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + strict = False + model_before_wrapping = model + model = model.unwrap() + + # Load from checkpoint. Since the logic of breaking parameter shards along tp degree + # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer, + # model.load_state_dict can be directly called. + state_dict = load_state_dict(checkpoint) + model.load_state_dict(state_dict, strict=strict) + + # Update master params if mixed-precision training is enabled. + model_before_wrapping.update_master_params() - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): # TODO(Baizhou): support this feature after implementing complete state_dict collection + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") raise NotImplementedError - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): # TODO(Baizhou): support this feature after implementing complete state_dict collection + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") raise NotImplementedError def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): From 093f726a99c00cbcb92d0d7ab74e67f53283e7f9 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 21 Sep 2023 17:49:27 +0800 Subject: [PATCH 2/5] support optimizer unsharded saving --- .../hybrid_parallel_checkpoint_io.py | 70 +++++++++++++++++-- ...st_hybrid_parallel_plugin_checkpoint_io.py | 3 +- 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 5aeabd29cb7e..90790bcfe150 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -120,13 +120,13 @@ def _optimizer_sharder( use_zero: bool, dp_group: ProcessGroup, tp_group: ProcessGroup, - master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, size_per_shard: int = 1024, ): # An internel method that breaks state_dict of optimizer into shards within limited size. state_dict_sharder = StateDictSharder(size_per_shard) param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() for param, state in optimizer.optim.state.items(): if param is None: @@ -400,7 +400,6 @@ def save_sharded_optimizer( use_zero=self.use_zero, dp_group=self.dp_group, tp_group=self.tp_group, - master_to_working_map=optimizer.get_master_to_working_map(), size_per_shard=size_per_shard, ) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) @@ -651,13 +650,68 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo model_before_wrapping.update_master_params() def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") - raise NotImplementedError + """ + Save optimizer state dict to a checkpoint file with given path. + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict. + checkpoint (str): Path to save optimizer state_dict. + gather_dtensor (bool): Whether to gather_dtensor, not used. + """ + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + + local_states = dict() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + # working param is needed for obtaining correct param_id + master_to_working_map = optimizer.get_master_to_working_map() + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + # gather complete state from tp shards & dp shards + param_id = optimizer.param_info["param2id"][id(working_param)] + original_shape = optimizer.param_info["param2shape"][id(working_param)] + local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=self.dp_group, + tp_group=self.tp_group, + use_zero=self.use_zero, + inplace=False, + device=torch.device("cuda"), + ) + + if self.pp_size == 1: + # When pipeline is not used, let master rank directly save the collected state_dict. + state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states} + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + else: + # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. + states_list = [None for _ in range(self.pp_size)] + dist.barrier(self.pp_group) + dist.all_gather_object(states_list, local_states, self.pp_group) + + # Only the master rank do the saving. + if self.coordinator.is_master(): + state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()} + for _states in states_list: + state_dict["state"].update(_states) + save_state_dict(state_dict, checkpoint, use_safetensors=False) def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): # TODO(Baizhou): support this feature after implementing complete state_dict collection - logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") + if self.coordinator.is_master(): + logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") raise NotImplementedError def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): @@ -676,6 +730,7 @@ def gather_from_sharded_optimizer_state( tp_group: ProcessGroup, use_zero: bool, inplace: bool, + device: torch.device = torch.device("cpu"), ) -> OrderedDict: """ With given parameter and its optimizer states, gather the complete optimizer state for saving. @@ -688,6 +743,7 @@ def gather_from_sharded_optimizer_state( tp_group (ProcessGroup): The process group of tensor parallel. use_zero (bool): Whether Zero is used. inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). Returns: OrderedDict: The complete optimizer state of given parameter. @@ -713,7 +769,7 @@ def gather_from_sharded_optimizer_state( dist.all_gather(gather_tensor, v, group=tp_group) v = torch.cat(gather_tensor, dim=partition_dim) - state_[k] = v.detach().clone().cpu() + state_[k] = v.detach().clone().to(device) return state_ diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index e8bb8f9e3475..711bd4d214a8 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -20,9 +20,8 @@ from tests.kit.model_zoo import model_zoo -# TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() -@parameterize("shard", [True]) +@parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_gpt"]) @parameterize("size_per_shard", [32]) @parameterize( From 1184ab518f4ecb8fb02c68d9bb1a9b16ca97506e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 21 Sep 2023 17:52:29 +0800 Subject: [PATCH 3/5] update doc --- docs/source/en/basics/booster_plugins.md | 2 -- docs/source/zh-Hans/basics/booster_plugins.md | 2 -- 2 files changed, 4 deletions(-) diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index 57fa813436da..a3df44fc6780 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -74,8 +74,6 @@ This plugin implements the combination of various parallel training strategies a > ⚠ When using this plugin, only the subset of Huggingface transformers supported by Shardformer are compatible with tensor parallel, pipeline parallel and optimization tools. Mainstream transformers such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 etc. are all supported by Shardformer. -> ⚠ This plugin only supports sharded checkpointing methods for model/optimizer at present. Unsharded checkpointing methods will be supported in future release. - {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} ### Torch DDP Plugin diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index d4ef7012ff67..8d8a288da949 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -71,8 +71,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 > ⚠ 在使用该插件的时候, 只有支持Shardformer的部分Huggingface transformers模型才能够使用张量并行、流水线并行以及优化工具。Llama 1、Llama 2、OPT、Bloom、Bert以及GPT2等主流transformers模型均已支持Shardformer。 -> ⚠ 该插件当前只对模型和优化器支持分片的checkpoint方法。不分片的checkpoint方法会在未来的版本中被支持。 - {{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }} ### Torch DDP 插件 From 0a6e09b6ff0e08d48fbd0771b931f6a829691105 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 21 Sep 2023 19:49:32 +0800 Subject: [PATCH 4/5] support unsharded loading for optimizer --- .../hybrid_parallel_checkpoint_io.py | 62 +++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 90790bcfe150..7155cd992d8a 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -651,7 +651,7 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ - Save optimizer state dict to a checkpoint file with given path. + Save optimizer state dict to a file with given path. Args: optimizer (OptimizerWrapper): Optimizer to save sharded state_dict. @@ -663,6 +663,7 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + # optimizer states of parameters kept by local device('s pipeline stage) local_states = dict() for param, state in optimizer.optim.state.items(): @@ -708,11 +709,64 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, state_dict["state"].update(_states) save_state_dict(state_dict, checkpoint, use_safetensors=False) - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from a file with given path. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + """ + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + if self.coordinator.is_master(): logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") - raise NotImplementedError + + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + # Complete optimizer state_dict loaded from checkpoint, need to be processed later. + state_dict = load_state_dict(checkpoint) + + # Load param_groups. + updated_groups = [] + saved_groups = state_dict["param_groups"] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. First discard those states not belonging to current pipeline stage. + master_to_working_map = optimizer.get_master_to_working_map() + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ From 49f53a37ae897723db8cd9527260712db11cde16 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 25 Sep 2023 15:30:58 +0800 Subject: [PATCH 5/5] small fix --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 7155cd992d8a..779ff42d75a1 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -740,7 +740,7 @@ def _get_param_id_from_optimizer_param( saved_groups = state_dict["param_groups"] for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = old_pg["params"] + new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage. updated_groups.append(new_pg) optimizer.optim.__dict__.update({"param_groups": updated_groups}) @@ -755,6 +755,8 @@ def _get_param_id_from_optimizer_param( # Then shard the loaded optimizer states if using tp/zero. for param, state in optimizer.optim.state.items(): + if param is None: + continue device = param.device if master_to_working_map is not None: working_param = master_to_working_map[id(param)]