Skip to content

Commit

Permalink
feat(ckpt): optimize model checkpointing in Volc and Ali (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
season0528 authored Mar 4, 2024
1 parent 4c72165 commit e465142
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 22 deletions.
5 changes: 4 additions & 1 deletion internlm/utils/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,8 +927,11 @@ def __init__(

self.async_upload = get_config_value(ckpt_config, "async_upload", False)

use_processpool = self.save_ckpt_folder is not None and (
self.save_ckpt_folder.startswith("volc:") or self.save_ckpt_folder.startswith("oss2:")
)
# initialization storage manager
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)
init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload, use_processpool)

self.feishu_address = feishu_address
self.storage_manager = get_storage_manager()
Expand Down
68 changes: 47 additions & 21 deletions internlm/utils/storage_manager.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,41 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import asyncio
import concurrent.futures
import hashlib
import io
import os
import pickle
import re
import socket
import stat
from asyncio import InvalidStateError
from asyncio.tasks import ALL_COMPLETED
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Union

import torch
import torch.distributed as dist
import multiprocessing

import dill

dill.Pickler.dumps, dill.Pickler.loads = dill.dumps, dill.loads
multiprocessing.reduction.ForkingPickler = dill.Pickler
multiprocessing.reduction.dump = dill.dump

import asyncio # noqa: E402 #pylint: disable=wrong-import-position
import concurrent.futures # noqa: E402 #pylint: disable=wrong-import-position
import hashlib # noqa: E402 #pylint: disable=wrong-import-position
import io # noqa: E402 #pylint: disable=wrong-import-position
import os # noqa: E402 #pylint: disable=wrong-import-position
import pickle # noqa: E402 #pylint: disable=wrong-import-position
import re # noqa: E402 #pylint: disable=wrong-import-position
import socket # noqa: E402 #pylint: disable=wrong-import-position
import stat # noqa: E402 #pylint: disable=wrong-import-position
from asyncio import ( # noqa: E402 #pylint: disable=wrong-import-position
InvalidStateError,
)
from asyncio.tasks import ( # noqa: E402 #pylint: disable=wrong-import-position
ALL_COMPLETED,
)
from datetime import datetime # noqa: E402 #pylint: disable=wrong-import-position
from typing import ( # noqa: E402 #pylint: disable=wrong-import-position
Any,
Awaitable,
Callable,
Dict,
List,
Union,
)

import torch # noqa: E402 #pylint: disable=wrong-import-position
import torch.distributed as dist # noqa: E402 #pylint: disable=wrong-import-position

try:
import boto3
Expand Down Expand Up @@ -976,7 +995,9 @@ class StorageManager(metaclass=SingletonMeta):
}
CLI_DICT = {}

def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, n_async_workers=8) -> None:
def __init__(
self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=True, use_processpool=False, n_async_workers=8
) -> None:
self._exception_list = []
self._to_be_del_files = []
self._async_stack = []
Expand All @@ -985,14 +1006,18 @@ def __init__(self, enable_save, tmp_local_folder="/dev/shm/test/", async_mode=Tr
self.async_mode = async_mode
self.has_warning = False
self._async_loop = None
self._thread_pool = None
self._executor_pool = None
self.latest_save_folder = None
self.latest_save_step = 0
self.async_task_peeding = False

if enable_save and self.async_mode:
self._async_loop = asyncio.new_event_loop()
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)

if use_processpool:
self._executor_pool = concurrent.futures.ProcessPoolExecutor(max_workers=n_async_workers)
else:
self._executor_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)

check_tmp_folder_accessibility(os.path.dirname(self.tmp_local_folder))

Expand Down Expand Up @@ -1196,7 +1221,7 @@ def async_executor(self, fn: Callable, *args, **kwargs) -> None:
"""
if not self._async_loop:
raise RuntimeError("Event loop was not initialized, please call this function in async or parallel mode")
t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
t = self._async_loop.run_in_executor(self._executor_pool, fn, *args, **kwargs)
self._async_stack.append(t)

def wait(self) -> bool:
Expand Down Expand Up @@ -1242,12 +1267,13 @@ def wait(self) -> bool:
storage_manager: StorageManager = None


def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload, use_processpool=False):
global storage_manager
storage_manager = StorageManager(
enable_save_ckpt,
tmp_local_folder=async_upload_tmp_folder,
async_mode=async_upload,
use_processpool=use_processpool,
)


Expand Down
8 changes: 8 additions & 0 deletions tests/test_utils/test_model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import multiprocessing

backup_ForkingPickler= multiprocessing.reduction.ForkingPickler
backup_dump = multiprocessing.reduction.dump
import os
from functools import partial

Expand Down Expand Up @@ -344,6 +348,10 @@ def query_quit_file(rank, world_size=2):

def test_quit_siganl_handler(): # noqa # pylint: disable=unused-import
import multiprocessing
# we do hack here to workaround the bug of 3rd party library dill, which only occurs in this unittest:
# https://github.com/uqfoundation/dill/issues/380
multiprocessing.reduction.ForkingPickler = backup_ForkingPickler
multiprocessing.reduction.dump = backup_dump
from multiprocessing.pool import Pool

world_size = 2
Expand Down

0 comments on commit e465142

Please sign in to comment.