Skip to content

Commit

Permalink
Update base for Update on "some compile-related improvements"
Browse files Browse the repository at this point in the history
1. Adds a CI test for 1D compile + selective op AC, which used to fail silently.
2. The flag `torch._dynamo.config.inline_inbuilt_nn_modules` is enabled to accelerate compilation (for llama3 8b on 8 H100, compile time drops from 9+ seconds to 6+ seconds), per anijain2305's suggestion.
3. It seems per TransformerBlock compile now works without `dynamic=False` and `fullgraph=True`. It is good to reflect the progress and catch regressions, per bdhirsh's suggestion.



[ghstack-poisoned]
  • Loading branch information
tianyu-l committed Aug 2, 2024
2 parents 7838b6f + 72a1614 commit 857d28d
Show file tree
Hide file tree
Showing 38 changed files with 859 additions and 475 deletions.
1 change: 1 addition & 0 deletions .github/workflows/integration_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ jobs:
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ Our guiding principles when building `torchtitan`:

[![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!")

### Dive into the code

You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model
* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)

## Pre-Release Updates:
#### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly).
Expand Down Expand Up @@ -66,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th
```bash
# Get your HF token from https://huggingface.co/settings/tokens

# llama3 tokenizer.model
# llama3 or 3.1 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...

# llama2 tokenizer.model
Expand Down
2 changes: 0 additions & 2 deletions create_seed_checkpoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

set -ex

export USE_LIBUV=1
TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
NGPU=1
LOG_RANK=0
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
Expand Down
47 changes: 27 additions & 20 deletions estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,19 @@
import os

import torch
import torch.nn.functional as F
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed import destroy_process_group
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.tensor.parallel import loss_parallel
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import create_tokenizer
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.datasets import build_tokenizer
from torchtitan.float8_linear import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from train import build_optimizers
from train import get_train_context


def estimate_memory(job_config: JobConfig):
Expand Down Expand Up @@ -61,16 +58,18 @@ def estimate_memory(job_config: JobConfig):
logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
job_config.model.norm_type = "rmsnorm"

if job_config.training.compile:
if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
logger.info("Compile mode is not supported yet. Switching to eager mode.")
job_config.training.compile = False
job_config.experimental.enable_compiled_autograd = False

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
Expand All @@ -93,16 +92,18 @@ def estimate_memory(job_config: JobConfig):

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# loss_parallel enables dispatching to efficient loss operators
loss_parallel_ctx = (
loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
def loss_fn(pred, labels):
return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
Expand All @@ -123,9 +124,10 @@ def loss_fn(pred, labels):
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)
# a no-op hander if fp8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear base on fp8 config
float8_handler.convert_to_float8_training(whole_model)

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [whole_model]
Expand All @@ -143,7 +145,7 @@ def loss_fn(pred, labels):

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

for model in model_parts:
model.train()
Expand All @@ -170,7 +172,7 @@ def loss_fn(pred, labels):
for iter_idx in range(2):
input_ids, labels = batch
# train step
with loss_parallel_ctx():
with train_context():
pred = whole_model(input_ids)
loss = loss_fn(pred, labels)
del pred
Expand All @@ -181,9 +183,14 @@ def loss_fn(pred, labels):
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model)
# optimizer step
optimizers.step()
lr_schedulers.step()
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
Expand Down Expand Up @@ -217,4 +224,4 @@ def loss_fn(pred, labels):
try:
estimate_memory(config)
finally:
destroy_process_group()
torch.distributed.destroy_process_group()
1 change: 0 additions & 1 deletion multinode_trainer.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
#export TORCH_DIST_INIT_BARRIER=1
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
#export USE_LIBUV=1
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama2_13b.toml"}

dcgmi profile --pause
Expand Down
27 changes: 3 additions & 24 deletions run_llama_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,18 @@

set -ex

# libUV is a scalable backend for TCPStore which is used in processGroup
# rendezvous. This is the recommended backend for distributed training.
export USE_LIBUV=1
TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan}

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh

NGPU=${NGPU:-"8"}
NNODES=${NNODES:-"1"}

# by default log just rank 0 output,
LOG_RANK=${LOG_RANK:-0}


CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
fi

# Check if --estimate.memory=True is in the arguments
if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
# Calculate WORLD_SIZE as the product of NGPU and NNODES
# Export WORLD_SIZE and LOCAL_RANK
export WORLD_SIZE=$((NGPU * NNODES))
export LOCAL_RANK=0
python estimation.py --job.config_file ${CONFIG_FILE} $overrides
else
# Call train.py if not in estimation mode
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
fi
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
26 changes: 26 additions & 0 deletions run_memory_estimation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# NGPU=4 ./run_memory_estimation.sh
NGPU=${NGPU:-"8"}
NNODES=${NNODES:-"1"}
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
fi

# Calculate WORLD_SIZE as the product of NGPU and NNODES
# Export WORLD_SIZE and LOCAL_RANK
export WORLD_SIZE=$((NGPU * NNODES))
export LOCAL_RANK=0
python estimation.py --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides
4 changes: 2 additions & 2 deletions test/datasets/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer
from torchtitan.datasets.tokenizer import build_tokenizer


class TestCheckpoint:
Expand Down Expand Up @@ -42,7 +42,7 @@ def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer_type = "tiktoken"
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
tokenizer = build_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
return build_hf_data_loader(
dataset_name=dataset_name,
dataset_path=dataset_path,
Expand Down
27 changes: 27 additions & 0 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def build_test_list():
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped flexible 1f1b test",
"pp_looped_flexible_1f1b",
requires_seed_checkpoint=True,
ngpu=4,
),
OverrideDefinitions(
[
[
Expand Down Expand Up @@ -273,6 +288,16 @@ def build_test_list():
"fsdp2_mem_tracker",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_type ddp",
]
],
"DDP",
"ddp",
ngpu=4,
),
]
return integration_tests_flavors

Expand Down Expand Up @@ -304,6 +329,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):

for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
if test_name == "fsdp2_mem_tracker":
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh"
cmd += " " + dump_folder_arg
cmd += " " + model_flavor_arg
if override_arg:
Expand Down
45 changes: 42 additions & 3 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import re
import shutil
import time
from dataclasses import dataclass, field
from io import BytesIO
from multiprocessing import get_context
from typing import Any, Dict, List, Union

Expand All @@ -27,7 +29,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger
from torchtitan.logging import init_logger, logger


class IntervalType(enum.Enum):
Expand All @@ -41,6 +43,43 @@ class AsyncMode(str, enum.Enum):
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


@dataclass
class TrainState(Stateful):
step: int = 0
global_avg_losses: List[float] = field(default_factory=list)
global_max_losses: List[float] = field(default_factory=list)
log_steps: List[int] = field(default_factory=list)

def state_dict(self) -> Dict[str, Any]:
# Only checkpoint global_avg_losses and global_max_losses per log frequency
# to avoid sync overhead in every iteration.
global_avg_losses_bytes = BytesIO()
torch.save(self.global_avg_losses, global_avg_losses_bytes)
global_max_losses_bytes = BytesIO()
torch.save(self.global_max_losses, global_max_losses_bytes)
log_steps_bytes = BytesIO()
torch.save(self.log_steps, log_steps_bytes)
return {
"step": torch.tensor(self.step, dtype=torch.int32),
"global_avg_losses": global_avg_losses_bytes,
"global_max_losses": global_max_losses_bytes,
"log_steps": log_steps_bytes,
}

def load_state_dict(self, state_dict) -> None:
self.step = state_dict["step"].item()
state_dict["global_avg_losses"].seek(0)
self.global_avg_losses = torch.load(
state_dict["global_avg_losses"], weights_only=False
)
state_dict["global_max_losses"].seek(0)
self.global_max_losses = torch.load(
state_dict["global_max_losses"], weights_only=False
)
state_dict["log_steps"].seek(0)
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)


class ModelWrapper(Stateful):
def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
Expand Down Expand Up @@ -124,10 +163,10 @@ def checkpoint_mp(recv, send):
class CheckpointManager:
def __init__(
self,
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: List[torch.optim.Optimizer],
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
Expand Down Expand Up @@ -390,7 +429,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
f"in {time.monotonic() - begin:.2f} seconds."
)

def wait_for_staging(self) -> None:
def maybe_wait_for_staging(self) -> None:
if (
self.enable_checkpoint
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
Expand Down
Loading

0 comments on commit 857d28d

Please sign in to comment.