Skip to content

Commit

Permalink
[ckpt] Add async ckpt api (#6136)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
wangbluo authored Nov 15, 2024
1 parent 79224bd commit cb00f1f
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ jobs:
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
DISABLE_URING=1 pip install -v --no-cache-dir .
- name: Store TensorNVMe Cache
run: |
Expand Down
1 change: 1 addition & 0 deletions colossalai/booster/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def save_model(
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
"""
self.checkpoint_io.save_model(
model,
Expand Down
57 changes: 37 additions & 20 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ def __init__(self) -> None:
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()

def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
):
"""
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
Expand All @@ -74,7 +81,10 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)

def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Expand Down Expand Up @@ -112,6 +122,7 @@ def save_sharded_model(
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save sharded model.
Expand All @@ -130,27 +141,33 @@ def save_sharded_model(

# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
if use_async:
super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
)

# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)

# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.",
ranks=[0],
)

def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
):
Expand Down
5 changes: 4 additions & 1 deletion colossalai/booster/plugin/torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
Expand Down Expand Up @@ -82,6 +84,7 @@ def save_sharded_model(
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.
Expand Down
4 changes: 2 additions & 2 deletions colossalai/checkpoint_io/checkpoint_io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ def save_model(

if shard:
self.save_sharded_model(
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def save_unsharded_model(
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:

else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)

Expand Down
77 changes: 55 additions & 22 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
async_save_state_dict_shards,
create_pinned_state_dict,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
Expand Down Expand Up @@ -177,6 +179,7 @@ def save_sharded_model(
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
Expand All @@ -194,6 +197,7 @@ def save_sharded_model(
prefix (str, optional): Perfix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
Expand All @@ -219,24 +223,27 @@ def save_sharded_model(

if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

else:
# When pipeline is used, each stage produces its own shard files and index files.
Expand All @@ -251,7 +258,16 @@ def save_sharded_model(
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)

if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_pp_format=True,
n_write_entries=191,
)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
Expand Down Expand Up @@ -626,7 +642,9 @@ def _get_param_id_from_optimizer_param(
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model state dict to a single file with given checkpointing path.
Expand All @@ -635,6 +653,7 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
Expand All @@ -651,7 +670,10 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
state_dict_list = [None for _ in range(self.pp_size)]
Expand All @@ -662,7 +684,18 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
complete_state_dict = dict()
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
"""
Expand Down
52 changes: 33 additions & 19 deletions colossalai/checkpoint_io/moe_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def save_sharded_model(
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
Expand Down Expand Up @@ -161,24 +162,27 @@ def save_sharded_model(

if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

dist.barrier()
else:
Expand Down Expand Up @@ -708,10 +712,20 @@ def save_unsharded_model(
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
if use_async:
super().save_unsharded_model(
model=model,
checkpoint=checkpoint,
gather_dtensor=gather_dtensor,
use_safetensors=use_safetensors,
use_async=use_async,
)
else:
torch.save(state_dict, checkpoint)
dist.barrier()

# Copied from colossalai.moe
Expand Down
13 changes: 5 additions & 8 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
def save_state_dict(
state_dict: dict,
checkpoint_file_path: str,
use_safetensors: bool,
) -> None:
"""
Save state dict to checkpoint.
Expand Down Expand Up @@ -581,14 +585,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open

with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file, map_location=torch.device("cpu"))
Expand Down
Loading

0 comments on commit cb00f1f

Please sign in to comment.