
Commit 594d045

Update on "some compile-related improvements"
1. Adds a CI test for 1D compile + selective op AC, which used to fail silently.
2. Enables the flag `torch._dynamo.config.inline_inbuilt_nn_modules` to accelerate compilation (for Llama 3 8B on 8 H100s, compile time drops from 9+ seconds to 6+ seconds), per anijain2305's suggestion.
3. Per-TransformerBlock compile now appears to work without `dynamic=False` and with `fullgraph=True`; it is good to reflect this progress and to catch regressions, per bdhirsh's suggestion.

[ghstack-poisoned]
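For illustration, a minimal sketch of items 2 and 3 above (this is not the commit's code; the `model.layers` container of TransformerBlocks is an assumption based on how the torchtitan Llama model is organized):

```python
import torch
import torch._dynamo
import torch.nn as nn

# Item 2: let Dynamo inline built-in nn.Module methods, which speeds up compilation.
torch._dynamo.config.inline_inbuilt_nn_modules = True

def compile_each_block(model: nn.Module) -> nn.Module:
    # Item 3: compile every TransformerBlock individually. fullgraph=True raises an
    # error on graph breaks, so regressions in per-block compile are caught early.
    # `model.layers` (a container of TransformerBlocks) is assumed for this sketch.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, torch.compile(block, fullgraph=True))
    return model
```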
2 parents: 96d467f + 857d28d

38 files changed: +863, -476 lines

.github/workflows/integration_test_4gpu.yaml (+1)

@@ -39,5 +39,6 @@ jobs:
 python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
 python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
+USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
 mkdir artifacts-to-be-uploaded
 python ./test_runner.py artifacts-to-be-uploaded --ngpu 4

README.md (+9, -1)

@@ -18,6 +18,14 @@ Our guiding principles when building `torchtitan`:

 [![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!")

+### Dive into the code
+
+You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
+* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
+* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model
+* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
+* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)
+
 ## Pre-Release Updates:
 #### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
 Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly).

@@ -66,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th
 ```bash
 # Get your HF token from https://huggingface.co/settings/tokens

-# llama3 tokenizer.model
+# llama3 or 3.1 tokenizer.model
 python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...

 # llama2 tokenizer.model

create_seed_checkpoint.sh (-2)

@@ -18,8 +18,6 @@

 set -ex

-export USE_LIBUV=1
-TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
 NGPU=1
 LOG_RANK=0
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

estimation.py (+27, -20)

@@ -9,22 +9,19 @@
 import os

 import torch
-import torch.nn.functional as F
 from torch._guards import active_fake_mode
 from torch._subclasses.fake_tensor import FakeTensorMode
-from torch.distributed import destroy_process_group
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
-from torch.distributed.tensor.parallel import loss_parallel
 from torch.testing._internal.distributed.fake_pg import FakeStore

 from torchtitan.config_manager import JobConfig
-from torchtitan.datasets import create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
-from torchtitan.logging_utils import init_logger, logger
-from torchtitan.lr_scheduling import get_lr_schedulers
+from torchtitan.datasets import build_tokenizer
+from torchtitan.float8_linear import Float8Handler
+from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import build_optimizers
+from train import get_train_context


 def estimate_memory(job_config: JobConfig):
@@ -61,16 +58,18 @@ def estimate_memory(job_config: JobConfig):
         logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
         job_config.model.norm_type = "rmsnorm"

-    if job_config.training.compile:
+    if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
+        job_config.experimental.enable_compiled_autograd = False

     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
     )

     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
@@ -93,16 +92,18 @@ def estimate_memory(job_config: JobConfig):

     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+    tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
     )

     # loss fn can be shared by pipeline-parallel or non-pp execution
     def loss_fn(pred, labels):
-        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+        return torch.nn.functional.cross_entropy(
+            pred.flatten(0, 1), labels.flatten(0, 1)
+        )

     # build model (using meta init)
     model_cls = model_name_to_cls[model_name]
@@ -123,9 +124,10 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)

-    # apply fp8 linear module swap
-    if job_config.training.fp8_linear:
-        build_fp8_linear(whole_model, job_config)
+    # a no-op hander if fp8 is not enabled
+    float8_handler = Float8Handler(job_config, parallel_dims)
+    # swap to Float8Linear base on fp8 config
+    float8_handler.convert_to_float8_training(whole_model)

     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]
@@ -143,7 +145,7 @@ def loss_fn(pred, labels):

     # build optimizer after applying parallelisms to the model
     optimizers = build_optimizers(model_parts, job_config)
-    lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)
+    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

     for model in model_parts:
         model.train()
@@ -170,7 +172,7 @@ def loss_fn(pred, labels):
         for iter_idx in range(2):
             input_ids, labels = batch
             # train step
-            with loss_parallel_ctx():
+            with train_context():
                 pred = whole_model(input_ids)
                 loss = loss_fn(pred, labels)
             del pred
@@ -181,9 +183,14 @@ def loss_fn(pred, labels):
                 torch.nn.utils.clip_grad_norm_(
                     model.parameters(), job_config.training.max_norm, foreach=True
                 )
+                # sync float8 amaxes and scales
+                float8_handler.sync_float8_amax_and_scale_history(model)
             # optimizer step
             optimizers.step()
             lr_schedulers.step()
+            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+            # it issues a single all-reduce for all parameters at once for better performance
+            float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
             optimizers.zero_grad()
             print(f"Peak Memory at iter: {iter_idx}")
             fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
@@ -217,4 +224,4 @@ def loss_fn(pred, labels):
     try:
         estimate_memory(config)
     finally:
-        destroy_process_group()
+        torch.distributed.destroy_process_group()

multinode_trainer.slurm (-1)

@@ -53,7 +53,6 @@ export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
 export NCCL_BUFFSIZE=2097152
 #export TORCH_DIST_INIT_BARRIER=1
 export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
-#export USE_LIBUV=1
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama2_13b.toml"}

 dcgmi profile --pause

run_llama_train.sh (+3, -24)

@@ -7,39 +7,18 @@

 set -ex

-# libUV is a scalable backend for TCPStore which is used in processGroup
-# rendezvous. This is the recommended backend for distributed training.
-export USE_LIBUV=1
-TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan}
-
 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
-
 NGPU=${NGPU:-"8"}
-NNODES=${NNODES:-"1"}
-
-# by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}
-
-
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

 overrides=""
 if [ $# -ne 0 ]; then
 overrides="$*"
 fi

-# Check if --estimate.memory=True is in the arguments
-if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
-# Calculate WORLD_SIZE as the product of NGPU and NNODES
-# Export WORLD_SIZE and LOCAL_RANK
-export WORLD_SIZE=$((NGPU * NNODES))
-export LOCAL_RANK=0
-python estimation.py --job.config_file ${CONFIG_FILE} $overrides
-else
-# Call train.py if not in estimation mode
-torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
---local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --job.config_file ${CONFIG_FILE} $overrides
-fi
+torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+train.py --job.config_file ${CONFIG_FILE} $overrides

run_memory_estimation.sh (+26, new file)

@@ -0,0 +1,26 @@
+#!/usr/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -ex
+
+# use envs as local overrides for convenience
+# e.g.
+# NGPU=4 ./run_memory_estimation.sh
+NGPU=${NGPU:-"8"}
+NNODES=${NNODES:-"1"}
+CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
+
+overrides=""
+if [ $# -ne 0 ]; then
+overrides="$*"
+fi
+
+# Calculate WORLD_SIZE as the product of NGPU and NNODES
+# Export WORLD_SIZE and LOCAL_RANK
+export WORLD_SIZE=$((NGPU * NNODES))
+export LOCAL_RANK=0
+python estimation.py --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides

test/datasets/test_checkpoint.py (+2, -2)

@@ -6,7 +6,7 @@

 import torch
 from torchtitan.datasets.hf_datasets import build_hf_data_loader
-from torchtitan.datasets.tokenizer import create_tokenizer
+from torchtitan.datasets.tokenizer import build_tokenizer


 class TestCheckpoint:

@@ -42,7 +42,7 @@ def _build_dataloader(
         self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
     ):
         tokenizer_type = "tiktoken"
-        tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
+        tokenizer = build_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
         return build_hf_data_loader(
             dataset_name=dataset_name,
             dataset_path=dataset_path,

test_runner.py (+27)

@@ -46,6 +46,21 @@ def build_test_list():
     """
     integration_tests_flavors = defaultdict(list)
     integration_tests_flavors["debug_model.toml"] = [
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 4",
+                    "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
+                    "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
+                    "--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
+                ],
+            ],
+            "PP looped flexible 1f1b test",
+            "pp_looped_flexible_1f1b",
+            requires_seed_checkpoint=True,
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [

@@ -284,6 +299,16 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_type ddp",
+                ]
+            ],
+            "DDP",
+            "ddp",
+            ngpu=4,
+        ),
     ]
     return integration_tests_flavors

@@ -315,6 +340,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):

     for override_arg in test_flavor.override_args:
         cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
+        if test_name == "fsdp2_mem_tracker":
+            cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh"
         cmd += " " + dump_folder_arg
         cmd += " " + model_flavor_arg
         if override_arg:

torchtitan/checkpoint.py (+42, -3)

@@ -10,6 +10,8 @@
 import re
 import shutil
 import time
+from dataclasses import dataclass, field
+from io import BytesIO
 from multiprocessing import get_context
 from typing import Any, Dict, List, Union

@@ -27,7 +29,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import DataLoader
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.logging_utils import init_logger, logger
+from torchtitan.logging import init_logger, logger


 class IntervalType(enum.Enum):
@@ -41,6 +43,43 @@ class AsyncMode(str, enum.Enum):
     ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


+@dataclass
+class TrainState(Stateful):
+    step: int = 0
+    global_avg_losses: List[float] = field(default_factory=list)
+    global_max_losses: List[float] = field(default_factory=list)
+    log_steps: List[int] = field(default_factory=list)
+
+    def state_dict(self) -> Dict[str, Any]:
+        # Only checkpoint global_avg_losses and global_max_losses per log frequency
+        # to avoid sync overhead in every iteration.
+        global_avg_losses_bytes = BytesIO()
+        torch.save(self.global_avg_losses, global_avg_losses_bytes)
+        global_max_losses_bytes = BytesIO()
+        torch.save(self.global_max_losses, global_max_losses_bytes)
+        log_steps_bytes = BytesIO()
+        torch.save(self.log_steps, log_steps_bytes)
+        return {
+            "step": torch.tensor(self.step, dtype=torch.int32),
+            "global_avg_losses": global_avg_losses_bytes,
+            "global_max_losses": global_max_losses_bytes,
+            "log_steps": log_steps_bytes,
+        }
+
+    def load_state_dict(self, state_dict) -> None:
+        self.step = state_dict["step"].item()
+        state_dict["global_avg_losses"].seek(0)
+        self.global_avg_losses = torch.load(
+            state_dict["global_avg_losses"], weights_only=False
+        )
+        state_dict["global_max_losses"].seek(0)
+        self.global_max_losses = torch.load(
+            state_dict["global_max_losses"], weights_only=False
+        )
+        state_dict["log_steps"].seek(0)
+        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
+
+
 class ModelWrapper(Stateful):
     def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
def __init__(
126165
self,
166+
dataloader: DataLoader,
127167
model_parts: List[nn.Module],
128168
optimizers: List[torch.optim.Optimizer],
129169
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
130-
dataloader: DataLoader,
131170
states: Dict[str, Any],
132171
job_config: JobConfig,
133172
) -> None:
@@ -390,7 +429,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
390429
f"in {time.monotonic() - begin:.2f} seconds."
391430
)
392431

393-
def wait_for_staging(self) -> None:
432+
def maybe_wait_for_staging(self) -> None:
394433
if (
395434
self.enable_checkpoint
396435
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
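
As an aside, a hypothetical round-trip (not part of the diff) showing how the new `TrainState` serializes its loss history through `BytesIO` buffers:

```python
from torchtitan.checkpoint import TrainState

state = TrainState(
    step=10,
    global_avg_losses=[2.5, 2.1],
    global_max_losses=[3.0, 2.8],
    log_steps=[5, 10],
)
# state_dict() stores the step as an int32 tensor and the lists as torch-saved
# BytesIO buffers; load_state_dict() rewinds each buffer and restores the lists.
restored = TrainState()
restored.load_state_dict(state.state_dict())
assert restored.step == 10 and restored.log_steps == [5, 10]
```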
