Add meta_init, enable it as default init process (pytorch#84)
This PR enables meta_init functionality to avoid OOM'ing on CPU for
larger models.
The core functionality is in meta_init.py, with a few supporting changes
in the parallelization code and train.py.
Key items:
1 - This is largely the same as my earlier meta_init PR, but I opened a
new one because that was faster than reworking the old branch against
all the interim changes.
2 - To address feedback on the previous PR:
a - Why do we need meta_init.py, can't we just do:
~~~
with torch.device("meta"):
    model = Model.from_args(...)
~~~
Unfortunately this does not work, because the RoPE embeddings are treated
differently (they are a buffer), so the simple lambda used as FSDP's
param_init_fn (lambda module: module.to_device('cuda')) will not touch or
move them, and the model fails on its first forward pass.
The issue relates to the embeddings not being moved and to the fact that
the current RoPE class references the device in its forward pass.
I have opened pytorch#110 to track and investigate this, rather than
holding up a working meta init from landing. A minimal sketch of the
pattern this PR uses instead is shown below.
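For reference, a minimal sketch of the pattern this PR lands (illustrative only; the real code path goes through enable_wrap/wrap in parallelize_llama.py, and model_cls, model_config, and dp_mesh stand in for objects that train.py and the parallelization code already build):
~~~
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn

# 1. Build the model under the meta-device context manager, so no real
#    CPU memory is allocated for the weights.
with meta_model_init():
    model = model_cls.from_model_args(model_config)

# 2. FSDP materializes the meta parameters on GPU via param_init_fn
#    while it shards the model.
model = FSDP(model, device_mesh=dp_mesh, param_init_fn=meta_to_real_init_fn)

# 3. Materialization uses torch.randn_like as a placeholder, so run the
#    real initialization once the parameters live on device.
model.reset_parameters()
~~~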

b - Per earlier feedback, meta init is now not optional but simply the
default. This should ensure all models leverage it and that we aren't
missing anything for future meta_init work.

3 - Misc change: I switched the reported model parameter count to the
plain all-params count instead of the 'unique params' count, because the
deduplicated number does not match what people generally perceive as
model size. A toy illustration of the difference is shown below.
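For illustration only (toy code, not from this repo), here is a case where the two counts diverge, namely two registered Parameters sharing one storage:
~~~
import torch.nn as nn


class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)
        # a second Parameter object wrapping the same storage
        self.out_weight = nn.Parameter(self.emb.weight.data)


params = list(Tied().parameters())
print(sum(p.numel() for p in params))  # 80 <- plain count, now reported
print(sum(p.numel() for p in {p.data_ptr(): p for p in params}.values()))  # 40 <- old 'unique' count
~~~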

Testing:
Tested both the debugmodel and the 26B model with and without meta init
to confirm the loss curves are the same.
Note for future reference: if you get a bad init (a meta init failure),
you will simply not train (the loss stays the same every iteration).
If you fail to call reset_parameters() after FSDP wrapping, you will
still train (because we default to torch.randn_like), but your starting
loss will be 5x+ higher, telling you that the model has not been properly
initialized. A quick sanity check along these lines is sketched below.
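An illustrative sanity check for these failure modes (model and model_config refer to the objects built in train.py; the ln(vocab_size) rule of thumb assumes the standard cross-entropy LM loss):
~~~
import math

# nothing should still be on the meta device after materialization + reset
assert not any(p.is_meta for p in model.parameters())

# a proper random init gives a first cross-entropy loss of roughly
# ln(vocab_size), e.g. ~10.4 for a 32K vocab; a much larger starting loss
# suggests reset_parameters() never ran on the materialized weights
print(f"expected first-step loss ~ {math.log(model_config.vocab_size):.2f}")
~~~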
lessw2020 authored Mar 5, 2024
1 parent 2682144 commit afbf62a
Showing 6 changed files with 78 additions and 13 deletions.
48 changes: 48 additions & 0 deletions torchtrain/meta_init.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from contextlib import contextmanager

import torch
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened


@contextmanager
def meta_model_init():
    """init model on meta device"""
    saved_register_parameter = nn.Module.register_parameter
    saved_register_buffer = nn.Module.register_buffer

    def register_meta_param(module, name, param):
        saved_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(torch.device("meta")), **kwargs
            )

    def register_meta_buffer(module, name, buffer):
        saved_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    try:
        nn.Module.register_parameter = register_meta_param
        nn.Module.register_buffer = register_meta_buffer
        yield
    finally:
        nn.Module.register_parameter = saved_register_parameter
        nn.Module.register_buffer = saved_register_buffer


@torch.no_grad()
def meta_to_real_init_fn(module: nn.Module):
    for submodule in module.modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if not _is_fsdp_flattened(param) and param.is_meta:
                materialized_param = nn.Parameter(
                    torch.randn_like(param, device=torch.device("cuda"))
                )
                setattr(submodule, param_name, materialized_param)
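As a quick illustration of what the context manager above does (nn.BatchNorm1d is used here only because it registers both parameters and buffers; it is not part of this PR), everything registered inside the block ends up on the meta device:
~~~
import torch.nn as nn

from torchtrain.meta_init import meta_model_init

with meta_model_init():
    m = nn.BatchNorm1d(8)

assert all(p.is_meta for p in m.parameters())  # weight, bias
assert all(b.is_meta for b in m.buffers())  # running_mean, running_var, num_batches_tracked
~~~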
4 changes: 2 additions & 2 deletions torchtrain/metrics.py
@@ -193,8 +193,8 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
    param_list = list(model.parameters())
    if only_trainable:
        param_list = [p for p in param_list if p.requires_grad]
    unique_params = {p.data_ptr(): p for p in param_list}.values()
    return sum(p.numel() for p in unique_params)
    # unique_params = {p.data_ptr(): p for p in param_list}.values()
    return sum(p.numel() for p in param_list)


class MetricLogger:
5 changes: 3 additions & 2 deletions torchtrain/models/llama/__init__.py
@@ -7,10 +7,11 @@

llama_configs = {
    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
    "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
    "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
    "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
    "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
    "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
    "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
    "70B": ModelArgs(
        dim=8192,
        n_layers=80,
13 changes: 11 additions & 2 deletions torchtrain/models/llama/model.py
@@ -47,7 +47,9 @@ def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(dim))
        self.reset_parameters()

        # re-enable if not using meta-init
        # self.reset_parameters()

    def _norm(self, x: torch.Tensor):
        """
@@ -466,7 +468,14 @@ def __init__(self, model_args: ModelArgs):
        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

        # init model weights
        self.reset_parameters()

        # we are doing meta_init, which will call reset_parameters() after
        # the model is moved to actual device.
        # If you modify and are not using meta_init, you will need to call
        # reset_parameters() manually as below:

        # self.reset_parameters()

        rank0_log(f"Model built with: {self.model_args}")

    def reset_parameters(
14 changes: 9 additions & 5 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -35,6 +35,7 @@
)
from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.meta_init import meta_to_real_init_fn

logger = logging.getLogger(__name__)

@@ -193,6 +194,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

        fsdp_config = {
            "mixed_precision": MixedPrecision(
                param_dtype=torch.bfloat16,
@@ -204,12 +206,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            # When torch.compile is active, it requires us to set use_orig_params=True
            "use_orig_params": True,
            "device_mesh": dp_mesh,
            "param_init_fn": meta_to_real_init_fn,
        }

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            for layer_id, transformer_block in enumerate(model.layers):
                # before wrapping with FSDP, we need to make sure the layer is on GPU
                transformer_block = transformer_block.cuda()

                # apply selective AC
                transformer_block = checkpoint_wrapper(
@@ -220,10 +221,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                model.layers[layer_id] = wrap(transformer_block)

            # wrap the rest layers with FSDP
            model = wrap(model.cuda())
            model = wrap(model)

        rank0_log("Applied FSDP to the model...")
    else:
        model.cuda()

    # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
    model.cuda()
    # we have now moved from meta to device,
    # reset parameters for proper initialization
    model.reset_parameters()
    return model
7 changes: 5 additions & 2 deletions train.py
@@ -22,6 +22,7 @@
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
from torchtrain.lr_scheduling import get_lr_scheduler
from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor

from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
@@ -115,15 +116,17 @@ def main(job_config: JobConfig):
    )

    # build model
    # TODO: add meta initialization
    model_cls = model_name_to_cls[model_name]
    model_config = models_config[model_name][job_config.model.flavor]
    model_config.vocab_size = tokenizer.n_words

    model = model_cls.from_model_args(model_config)
    # build model using meta init
    with meta_model_init():
        model = model_cls.from_model_args(model_config)

    # log model size
    model_param_count = get_num_params(model)

    if _is_local_logging:
        rank0_log(
            f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
