Add meta_init, enable it as default init process #84

Merged
merged 12 commits on Mar 5, 2024
48 changes: 48 additions & 0 deletions torchtrain/meta_init.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from contextlib import contextmanager

import torch
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened


@contextmanager
def meta_model_init():
"""init model on meta device"""
saved_register_parameter = nn.Module.register_parameter
saved_register_buffer = nn.Module.register_buffer

def register_meta_param(module, name, param):
saved_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
module._parameters[name] = param_cls(
module._parameters[name].to(torch.device("meta")), **kwargs
)

def register_meta_buffer(module, name, buffer):
saved_register_buffer(module, name, buffer)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(torch.device("meta"))

try:
nn.Module.register_parameter = register_meta_param
nn.Module.register_buffer = register_meta_buffer
yield
finally:
nn.Module.register_parameter = saved_register_parameter
nn.Module.register_buffer = saved_register_buffer


@torch.no_grad()
def meta_to_real_init_fn(module: nn.Module):
for submodule in module.modules():
for param_name, param in submodule.named_parameters(recurse=False):
if not _is_fsdp_flattened(param) and param.is_meta:
materialized_param = nn.Parameter(
torch.randn_like(param, device=torch.device("cuda"))
)
setattr(submodule, param_name, materialized_param)
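
For illustration, a minimal usage sketch of the two helpers above (not part of the diff). It assumes a CUDA device is available, since meta_to_real_init_fn materializes parameters onto cuda:

import torch
from torch import nn
from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn

with meta_model_init():
    layer = nn.Linear(1024, 1024)  # parameters are registered on the meta device
assert layer.weight.is_meta        # no real memory has been allocated yet

meta_to_real_init_fn(layer)        # swaps meta params for randn-initialized CUDA tensors
assert layer.weight.is_cuda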
4 changes: 2 additions & 2 deletions torchtrain/metrics.py
@@ -192,8 +192,8 @@ def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
param_list = list(model.parameters())
if only_trainable:
param_list = [p for p in param_list if p.requires_grad]
unique_params = {p.data_ptr(): p for p in param_list}.values()
return sum(p.numel() for p in unique_params)
# unique_params = {p.data_ptr(): p for p in param_list}.values()
return sum(p.numel() for p in param_list)


class MetricLogger:
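As a small illustration (not part of the diff), the simplified count still works for a model built on the meta device, because numel() needs only shape metadata:

from torch import nn
from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import get_num_params

with meta_model_init():
    tiny = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# 2 * (8*8 weights + 8 biases) = 144, counted without allocating real storage
assert get_num_params(tiny) == 144
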
2 changes: 1 addition & 1 deletion torchtrain/models/llama/__init__.py
@@ -10,7 +10,7 @@
"1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
"7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
"13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
"40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
"26B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
"70B": ModelArgs(
dim=8192,
n_layers=80,
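For context on the rename above: a back-of-the-envelope count for dim=5120, n_layers=80, n_heads=40 lands near 26B parameters rather than 40B. The FFN hidden size (13824) and 32k vocabulary used below are assumptions based on common Llama defaults, not values shown in this diff:

dim, n_layers, ffn_hidden, vocab = 5120, 80, 13824, 32000
attn = 4 * dim * dim                                # wq, wk, wv, wo
ffn = 3 * dim * ffn_hidden                          # w1, w2, w3
total = n_layers * (attn + ffn) + 2 * vocab * dim   # blocks + token embedding + output projection
print(f"{total / 1e9:.1f}B parameters")             # ~25.7B, hence "26B"
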
10 changes: 5 additions & 5 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -35,6 +35,7 @@
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log
from torchtrain.meta_init import meta_to_real_init_fn

logger = logging.getLogger(__name__)

@@ -153,6 +154,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

fsdp_config = {
"mixed_precision": MixedPrecision(
param_dtype=torch.bfloat16,
@@ -164,23 +166,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
# When torch.compile is active, it requires us to set use_orig_params=True
"use_orig_params": True,
"device_mesh": dp_mesh,
"param_init_fn": meta_to_real_init_fn,
}

with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):

# apply AC to each layer
# before wrapping with FSDP, we need to make sure the layer is on GPU
transformer_block = transformer_block.cuda()
transformer_block = checkpoint_wrapper(transformer_block, job_config)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)

# wrap the rest layers with FSDP
model = wrap(model.cuda())
model = wrap(model)

rank0_log("Applied FSDP to the model...")

# redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
model.cuda()
return model
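
For illustration (not part of the diff), a minimal sketch of how FSDP uses param_init_fn: when a wrapped module still holds meta parameters, FSDP calls the function to materialize them before sharding, so real CUDA storage is allocated one wrapped module at a time instead of all up front. It assumes torch.distributed has already been initialized (e.g. via torchrun) and a CUDA device is available:

from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn

with meta_model_init():
    block = nn.Linear(1024, 1024)  # allocated on the meta device only

# FSDP materializes the meta parameters via param_init_fn before flattening and sharding them
wrapped = FSDP(block, param_init_fn=meta_to_real_init_fn, use_orig_params=True)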
11 changes: 9 additions & 2 deletions train.py
@@ -21,6 +21,7 @@
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
from torchtrain.lr_scheduling import get_lr_scheduler
from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor

from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
@@ -108,18 +109,20 @@ def main(job_config: JobConfig):
)

# build model
# TODO: add meta initialization
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
model_config.vocab_size = tokenizer.n_words

model = model_cls.from_model_args(model_config)
# build model using meta init
with meta_model_init():
model = model_cls.from_model_args(model_config)

# log model size
model_param_count = get_num_params(model)
rank0_log(
f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
)

gpu_metrics = GPUMemoryMonitor("cuda")
rank0_log(f"GPU memory usage: {gpu_metrics}")

@@ -128,6 +131,10 @@
model, world_mesh, parallel_dims, job_config
)

# we have now moved from meta to device,
# reset parameters for proper initialization
model.reset_parameters()

# to use FSDP-customized gradient scaler and gradient clipping solutions
assert isinstance(model, FSDP)

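The reset_parameters() call above assumes the model's submodules know how to re-initialize themselves once they have been materialized from the meta device. A hypothetical submodule following that pattern (not taken from this repository) might look like:

from torch import nn

class FeedForward(nn.Module):
    """Hypothetical block that re-initializes itself after materialization."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def reset_parameters(self):
        # overwrite the placeholder values written by meta_to_real_init_fn
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2.weight, mean=0.0, std=0.02)
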
39 changes: 39 additions & 0 deletions train_configs/llama7B.toml
@@ -0,0 +1,39 @@
# TorchTrain Config.toml
[job]
dump_folder = "./outputs"

[profiling]
run_profiler = false
save_traces_folder = "profiling/traces"
# profiling frequency - example: 10 means every 10th iter will be profiled
profile_every_x_iter = 10

[metrics]
enable_tensorboard = false
save_tb_folder = "tb"
log_freq = 10

[model]
name = "llama"
flavor = "7B"
tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 8e-4


[training]
batch_size = 8
seq_len = 2048
warmup_pct = 0.20 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
sequence_parallel_degree = 1
pipeline_parallel_degree = 1
compile = false
checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca"
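
For reference, a minimal sketch of reading this config with the standard library; in the training code the parsing is handled by torchtrain.config_manager.JobConfig, whose API is not shown in this diff:

import tomllib  # Python 3.11+; the tomli package provides the same API on older versions

with open("train_configs/llama7B.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["model"]["flavor"])         # "7B"
print(cfg["training"]["batch_size"])  # 8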