From 6529af1290fcc662444d57a759fe79f286f8ea39 Mon Sep 17 00:00:00 2001
From: Less Wright
Date: Mon, 4 Mar 2024 21:50:58 -0800
Subject: [PATCH] Add meta_init, enable it as default init process (#84)

This PR enables meta_init functionality to avoid OOM'ing on CPU for larger
models. The core functionality is in meta_init.py, with a few supporting
changes in the parallelization code and train.py.

Key items:

1 - This is largely the same as the earlier PR I had for meta_init, but I
opened a new one because that was faster than reworking the old one against
all the interim changes.

2 - To address feedback on the previous PR:

a - Why do we need meta_init.py? Can't we just do:
~~~
with torch.device("meta"):
    model = Model.from_args(...)
~~~
Unfortunately this does not work, because the rope embeddings are treated
differently (as a buffer), so the simple lambda used as param_init_fn in FSDP
(lambda module: module.to_device('cuda')) will not move the rope embeddings,
and the model fails on the first forward pass. The issue is that the
nn.embeddings are not moved, and that the device is referenced in the forward
pass of the current rope class. I have opened
https://github.com/pytorch/torchtrain/issues/110 to track and investigate
this, rather than hold up a working meta init from landing.

b - Per earlier feedback, meta init is now not optional but simply the
default. This ensures all models leverage it and that we aren't missing
anything for future meta_init work.

3 - Misc change: I switched the model params count to the normal all-params
count instead of 'unique params', because the latter does not match what
people perceive as model size.

Testing: tested both the debugmodel and the 26B model with and without meta
init to confirm the same loss curves.

Note for future reference: if you get a bad init (meta init failure) you will
simply not train (the loss is the same every iteration). If you fail to call
reset_parameters() after FSDP wrapping, you will still train (because we
default to torch.randn_like), but your starting loss will be 5x+ higher,
telling you the model has not been properly initialized.
---
 torchtrain/meta_init.py                      | 48 ++++++++++++++++++++
 torchtrain/metrics.py                        |  4 +-
 torchtrain/models/llama/__init__.py          |  5 +-
 torchtrain/models/llama/model.py             | 13 +++++-
 torchtrain/parallelisms/parallelize_llama.py | 14 ++++--
 train.py                                     |  7 ++-
 6 files changed, 78 insertions(+), 13 deletions(-)
 create mode 100644 torchtrain/meta_init.py

diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py
new file mode 100644
index 00000000..d67e6ef7
--- /dev/null
+++ b/torchtrain/meta_init.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from contextlib import contextmanager
+
+import torch
+from torch import nn
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
+
+
+@contextmanager
+def meta_model_init():
+    """init model on meta device"""
+    saved_register_parameter = nn.Module.register_parameter
+    saved_register_buffer = nn.Module.register_buffer
+
+    def register_meta_param(module, name, param):
+        saved_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            module._parameters[name] = param_cls(
+                module._parameters[name].to(torch.device("meta")), **kwargs
+            )
+
+    def register_meta_buffer(module, name, buffer):
+        saved_register_buffer(module, name, buffer)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(torch.device("meta"))
+
+    try:
+        nn.Module.register_parameter = register_meta_param
+        nn.Module.register_buffer = register_meta_buffer
+        yield
+    finally:
+        nn.Module.register_parameter = saved_register_parameter
+        nn.Module.register_buffer = saved_register_buffer
+
+
+@torch.no_grad()
+def meta_to_real_init_fn(module: nn.Module):
+    for submodule in module.modules():
+        for param_name, param in submodule.named_parameters(recurse=False):
+            if not _is_fsdp_flattened(param) and param.is_meta:
+                materialized_param = nn.Parameter(
+                    torch.randn_like(param, device=torch.device("cuda"))
+                )
+                setattr(submodule, param_name, materialized_param)
diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py
index d56d80a3..91a1e184 100644
--- a/torchtrain/metrics.py
+++ b/torchtrain/metrics.py
@@ -193,8 +193,8 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
     param_list = list(model.parameters())
     if only_trainable:
         param_list = [p for p in param_list if p.requires_grad]
-    unique_params = {p.data_ptr(): p for p in param_list}.values()
-    return sum(p.numel() for p in unique_params)
+    # unique_params = {p.data_ptr(): p for p in param_list}.values()
+    return sum(p.numel() for p in param_list)
 
 
 class MetricLogger:
diff --git a/torchtrain/models/llama/__init__.py b/torchtrain/models/llama/__init__.py
index c1f87f89..e6175ca9 100644
--- a/torchtrain/models/llama/__init__.py
+++ b/torchtrain/models/llama/__init__.py
@@ -7,10 +7,11 @@
 
 llama_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
-    "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
+    "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
+    "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
     "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
-    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
+    "26B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
     "70B": ModelArgs(
         dim=8192,
         n_layers=80,
diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 2cd81a6c..1ba505cf 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -47,7 +47,9 @@ def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.empty(dim))
-        self.reset_parameters()
+
+        # re-enable if not using meta-init
+        # self.reset_parameters()
 
     def _norm(self, x: torch.Tensor):
         """
@@ -466,7 +468,14 @@ def __init__(self, model_args: ModelArgs):
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
 
         # init model weights
-        self.reset_parameters()
+
+        # we are doing meta_init, which will call reset_parameters() after
+        # the model is moved to actual device.
+        # If you modify and are not using meta_init, you will need to call
+        # reset_parameters() manually as below:
+
+        # self.reset_parameters()
+
         rank0_log(f"Model built with: {self.model_args}")
 
     def reset_parameters(
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 698079a6..d11fac9f 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -35,6 +35,7 @@
 )
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import rank0_log
+from torchtrain.meta_init import meta_to_real_init_fn
 
 logger = logging.getLogger(__name__)
 
@@ -193,6 +194,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+
         fsdp_config = {
             "mixed_precision": MixedPrecision(
                 param_dtype=torch.bfloat16,
@@ -204,12 +206,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             # When torch.compile is active, it requires us to set use_orig_params=True
             "use_orig_params": True,
             "device_mesh": dp_mesh,
+            "param_init_fn": meta_to_real_init_fn,
         }
 
         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
-                # before wrapping with FSDP, we need to make sure the layer is on GPU
-                transformer_block = transformer_block.cuda()
 
                 # apply selective AC
                 transformer_block = checkpoint_wrapper(
@@ -220,10 +221,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 model.layers[layer_id] = wrap(transformer_block)
 
             # wrap the rest layers with FSDP
-            model = wrap(model.cuda())
+            model = wrap(model)
 
         rank0_log("Applied FSDP to the model...")
+    else:
+        model.cuda()
 
-    # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
-    model.cuda()
+    # we have now moved from meta to device,
+    # reset parameters for proper initialization
+    model.reset_parameters()
     return model
diff --git a/train.py b/train.py
index 3d4c3ae2..9c8e2f7b 100644
--- a/train.py
+++ b/train.py
@@ -22,6 +22,7 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
+from torchtrain.meta_init import meta_model_init
 from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor
 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
@@ -115,15 +116,17 @@ def main(job_config: JobConfig):
     )
 
     # build model
-    # TODO: add meta initialization
     model_cls = model_name_to_cls[model_name]
     model_config = models_config[model_name][job_config.model.flavor]
     model_config.vocab_size = tokenizer.n_words
-    model = model_cls.from_model_args(model_config)
+    # build model using meta init
+    with meta_model_init():
+        model = model_cls.from_model_args(model_config)
 
     # log model size
     model_param_count = get_num_params(model)
+
     if _is_local_logging:
         rank0_log(
             f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
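
Reviewer note: a minimal end-to-end sketch (not part of this patch) of how the new pieces fit together: build under meta_model_init, let FSDP materialize meta params via param_init_fn, then call reset_parameters(). ToyModel and build_and_shard are hypothetical stand-ins, it uses a single flat FSDP wrap rather than the per-block enable_wrap/wrap policy in this PR, and it assumes torch.distributed is already initialized with a CUDA device available.

~~~
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn


class ToyModel(nn.Module):
    # hypothetical stand-in for the real Transformer
    def __init__(self, dim: int = 256):
        super().__init__()
        self.w = nn.Linear(dim, dim, bias=False)

    def reset_parameters(self):
        # the "real" init; only meaningful once params are materialized on GPU
        nn.init.trunc_normal_(self.w.weight, std=0.02)

    def forward(self, x):
        return self.w(x)


def build_and_shard():
    # 1) build on the meta device: parameters/buffers take no real memory,
    #    so a large model cannot OOM the CPU at construction time
    with meta_model_init():
        model = ToyModel()

    # 2) FSDP materializes the meta params on GPU via param_init_fn
    #    (torch.randn_like placeholders, see meta_to_real_init_fn)
    model = FSDP(model, param_init_fn=meta_to_real_init_fn, use_orig_params=True)

    # 3) run the real initialization; skipping this step still trains,
    #    but from a much worse (randn) starting point
    model.reset_parameters()
    return model
~~~

The randn_like placeholders in step 2 are what make a missed reset_parameters() visible: training proceeds, but the starting loss is several times higher than with a proper init.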