[lazy] support from_pretrained #4801

Merged (3 commits) on Sep 26, 2023

Changes from all commits

.github/workflows/build_on_pr.yml (2 additions, 1 deletion)

@@ -141,7 +141,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 60
defaults:
run:
@@ -214,6 +214,7 @@ jobs:
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
+ LLAMA_PATH: /data/scratch/llama-tiny

- name: Store Testmon Cache
run: |
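
The llama-tiny volume mount and the LLAMA_PATH variable give the CI container access to a local tiny LLaMA checkpoint for the new from_pretrained tests; the same change is repeated in the other workflow files below, so a usage sketch is shown only once here. How a test might consume the variable (the test name and skip condition are assumptions, not taken from this PR):

    import os

    import pytest


    @pytest.mark.skipif("LLAMA_PATH" not in os.environ, reason="needs the llama-tiny checkpoint mounted in CI")
    def test_llama_path_is_mounted():
        # LLAMA_PATH points at /data/scratch/llama-tiny inside the CI container.
        assert os.path.isdir(os.environ["LLAMA_PATH"])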

.github/workflows/build_on_schedule.yml (2 additions, 1 deletion)

@@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, 8-gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 40
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
@@ -64,6 +64,7 @@ jobs:
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny

- name: Notify Lark
id: message-preparation

.github/workflows/compatiblity_test_on_dispatch.yml (2 additions, 1 deletion)

@@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny

.github/workflows/compatiblity_test_on_pr.yml (2 additions, 1 deletion)

@@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny

.github/workflows/compatiblity_test_on_schedule.yml (2 additions, 1 deletion)

@@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny

- name: Notify Lark
id: message-preparation

colossalai/booster/booster.py (8 additions, 0 deletions)

@@ -8,6 +8,7 @@
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper

@@ -131,6 +132,7 @@ def boost(
         """
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(FrankLeeeee): consider multi-dataloader case
+        pretrained_path = pretrained_utils.get_pretrained_path(model)
         # transform model for mixed precision
         if self.plugin:
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
@@ -146,6 +148,12 @@
             # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

+        if pretrained_path:
+            self.load_model(model, pretrained_path)
+            # clear pretrained path attr
+            orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
+            pretrained_utils.set_pretrained_path(orig_model, None)
+
         return model, optimizer, criterion, dataloader, lr_scheduler

     def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
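
With this change, boost() reads a checkpoint path that may have been recorded on the model (the lazy-init machinery is expected to record it when from_pretrained runs inside LazyInitContext, see the lazy_init.py change below), lets the plugin wrap and shard the model as before, then loads the recorded checkpoint through Booster.load_model and clears the marker on the unwrapped module. An end-to-end sketch of the intended usage, with an illustrative plugin and checkpoint path (neither is mandated by this PR):

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.lazy import LazyInitContext
    from transformers import LlamaForCausalLM

    colossalai.launch_from_torch(config={})

    with LazyInitContext():
        # No weights are materialized here; only the checkpoint path is recorded on the model.
        model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-tiny")

    booster = Booster(plugin=GeminiPlugin())
    # boost() wraps and shards the lazily initialized model, then loads the recorded checkpoint into it.
    model, *_ = booster.boost(model)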

colossalai/interface/pretrained.py (new file, 16 additions)

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from torch.nn import Module
+
+__all__ = [
+    "get_pretrained_path",
+    "set_pretrained_path",
+]
+
+
+def get_pretrained_path(model: Module) -> Optional[str]:
+    return getattr(model, "_pretrained", None)
+
+
+def set_pretrained_path(model: Module, path: str) -> None:
+    setattr(model, "_pretrained", path)
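
The two helpers above are a thin convention around a "_pretrained" attribute stored on the module, so other components can pass the checkpoint path along without depending on any model-library types. A small illustration (the Linear module is only an example):

    import torch.nn as nn

    from colossalai.interface.pretrained import get_pretrained_path, set_pretrained_path

    model = nn.Linear(4, 4)
    assert get_pretrained_path(model) is None                       # nothing recorded yet

    set_pretrained_path(model, "/data/scratch/llama-tiny")          # remember where the weights live
    assert get_pretrained_path(model) == "/data/scratch/llama-tiny"

    set_pretrained_path(model, None)                                # boost() clears the marker the same way after loading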

colossalai/lazy/lazy_init.py (3 additions, 0 deletions)

@@ -11,6 +11,7 @@
 from colossalai.logging import get_dist_logger

 from .construction import ConstructorManager
+from .pretrained import PretrainedManager

 import colossalai._analyzer._subclasses._meta_registration # noqa

@@ -595,11 +596,13 @@ def wrapper(*args, **kwargs):
         )

         ConstructorManager.apply(overrides)
+        PretrainedManager.inject()

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.tensor_cls.default_device = self.old_default_device
         LazyInitContext._replaced = False
         ConstructorManager.clear()
+        PretrainedManager.recover()

     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
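
PretrainedManager itself lives in colossalai/lazy/pretrained.py, which this PR only imports, so its implementation is not shown here. Judging from the ConstructorManager pattern it sits next to, inject() presumably swaps from_pretrained for a variant that builds the model without loading weights and records the checkpoint path via set_pretrained_path, and recover() restores the original on context exit. A toy sketch of that pattern, offered purely as an assumption about the mechanism (none of the Toy* names appear in the PR):

    import torch.nn as nn

    from colossalai.interface.pretrained import set_pretrained_path


    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        @classmethod
        def from_pretrained(cls, path: str) -> "ToyModel":
            model = cls()
            # eager behavior: weights would be read from `path` and loaded here
            return model


    class ToyPretrainedManager:
        """Monkey-patch from_pretrained on context enter, restore it on exit."""

        _orig_from_pretrained = None

        @classmethod
        def inject(cls):
            cls._orig_from_pretrained = ToyModel.from_pretrained.__func__

            def lazy_from_pretrained(model_cls, path: str) -> "ToyModel":
                model = model_cls()               # build the structure only, skip checkpoint loading
                set_pretrained_path(model, path)  # record where the weights live for boost() to load later
                return model

            ToyModel.from_pretrained = classmethod(lazy_from_pretrained)

        @classmethod
        def recover(cls):
            ToyModel.from_pretrained = classmethod(cls._orig_from_pretrained)


    ToyPretrainedManager.inject()
    lazy_model = ToyModel.from_pretrained("/tmp/toy-checkpoint")  # path recorded, no weights touched
    ToyPretrainedManager.recover()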