Add meta_init, enable it as default init process (#84)
This PR enables meta_init functionality to avoid OOM'ing on CPU for larger models. The core functionality is in meta_init.py, with a few supporting changes in parallelization and train.py.

Key items:

1 - This is largely the same as the earlier PR I had for meta_init, but I opened a new one because that was faster than reworking the old one against all the interim changes.

2 - To address feedback on the previous PR:

a - Why do we need meta_init.py? Can't we just do:

~~~
with torch.device("meta"):
    model = Model.from_args(...)
~~~

Unfortunately this does not work, because the rope embeddings are treated differently (as a buffer), so the simple lambda passed as FSDP's param_init_fn (`lambda module: module.to_device('cuda')`) will not invoke or move the rope embeddings, and the model fails on the first forward pass. The issue relates to the nn.embeddings not being moved, and to the device being referenced in the forward pass of the current rope class. Opened #110 to track and investigate this rather than hold up a working meta init from landing. (A runnable sketch of the failure mode follows below.)

b - Per earlier feedback, meta init is now not optional but simply the default. This ensures all models leverage it and that we aren't missing anything for future meta_init work.

3 - Misc change: I switched model_params to report the normal all-parameters count instead of 'unique params', because the latter does not match what people perceive as model size.

Testing: tested both the debugmodel and the 26B model with and without meta init to confirm identical loss curves.

Note for future reference: if you get a bad init (meta init failure) you simply will not train (the loss is the same every iteration). If you fail to call reset params after FSDP, you will still train (because we default to torch.randn_like), but your starting loss will be 5x+ higher, telling you the model has not been properly initialized.
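To make the failure mode above concrete, here is a small self-contained sketch (not torchtrain code; `TinyModel` and its `freqs` buffer are stand-ins for the transformer and its rope frequency table): constructing a model under the plain `torch.device("meta")` context leaves buffers on the meta device too, and per the commit message that buffer is exactly what the simple param_init_fn lambda fails to materialize.

~~~python
# Minimal illustration under assumed names (TinyModel, freqs); only shows
# that buffers created at construction time also land on the meta device.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        # Buffer created in __init__, analogous to the rope frequency table.
        self.register_buffer("freqs", torch.zeros(8))


with torch.device("meta"):
    model = TinyModel()

print(model.proj.weight.device)  # meta
print(model.freqs.device)        # meta -- per the commit message, this is the
# part the simple param_init_fn lambda misses, causing the first forward to fail.
~~~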
Showing 6 changed files with 78 additions and 13 deletions.
meta_init.py (new file)
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from contextlib import contextmanager

import torch
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened


@contextmanager
def meta_model_init():
    """init model on meta device"""
    saved_register_parameter = nn.Module.register_parameter
    saved_register_buffer = nn.Module.register_buffer

    def register_meta_param(module, name, param):
        saved_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(torch.device("meta")), **kwargs
            )

    def register_meta_buffer(module, name, buffer):
        saved_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    try:
        nn.Module.register_parameter = register_meta_param
        nn.Module.register_buffer = register_meta_buffer
        yield
    finally:
        nn.Module.register_parameter = saved_register_parameter
        nn.Module.register_buffer = saved_register_buffer


@torch.no_grad()
def meta_to_real_init_fn(module: nn.Module):
    for submodule in module.modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if not _is_fsdp_flattened(param) and param.is_meta:
                materialized_param = nn.Parameter(
                    torch.randn_like(param, device=torch.device("cuda"))
                )
                setattr(submodule, param_name, materialized_param)
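To show how the two helpers are meant to compose, here is a minimal wiring sketch. It is not the actual train.py/parallelization change in this commit: the import path, `build_and_shard_model`, `model_cls`, and `model_args` are placeholder names, and the real weight init that must run after FSDP wrapping is only indicated by a comment.

~~~python
# Hypothetical wiring sketch; the real changes live in train.py and the
# parallelization code of this commit.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from meta_init import meta_model_init, meta_to_real_init_fn  # assumed import path


def build_and_shard_model(model_cls, model_args):
    # 1) Construct on the meta device: no host memory is allocated for weights,
    #    so large models no longer OOM on CPU during construction.
    with meta_model_init():
        model = model_cls(model_args)

    # 2) Let FSDP materialize the model on GPU. param_init_fn replaces every
    #    remaining meta parameter with a real CUDA tensor (torch.randn_like).
    model = FSDP(model, param_init_fn=meta_to_real_init_fn)

    # 3) randn_like is only a placeholder init: the model's own weight init
    #    (reset params) must still be run after wrapping, otherwise training
    #    starts from a ~5x higher loss, as noted in the commit message.
    return model
~~~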