
Commit 8849580

[BE][4/n] split pipeline_llama into a separate file

ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: #499

1 parent c44cca0

9 files changed (+474, -462 lines)

README.md (+4, -2)

@@ -22,8 +22,10 @@ Our guiding principles when building `torchtitan`:
 
 You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
 * [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
-* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model
+* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
+* [torchtitan/parallelisms/pipeline_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
 * [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
+* [torchtitan/float8.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py) - utils for applying Float8 techniques
 * [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)
 
 ## Pre-Release Updates:

@@ -48,7 +50,7 @@ We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
 ### Coming soon
 
 1. Async checkpointing
-2. FP8 support
+2. Float8 support
 3. Context Parallel
 4. 3D Pipeline Parallel
 5. `torch.compile` support

estimation.py (+4, -4)

@@ -16,7 +16,7 @@
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_tokenizer
-from torchtitan.float8_linear import Float8Handler
+from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers

@@ -124,9 +124,9 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)
 
-    # a no-op hander if fp8 is not enabled
+    # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
-    # swap to Float8Linear base on fp8 config
+    # swap to Float8Linear based on float8 configs
     float8_handler.convert_to_float8_training(whole_model)
 
     # apply PT-D DP/TP parallelisms and activation checkpointing

@@ -190,7 +190,7 @@ def loss_fn(pred, labels):
             lr_schedulers.step()
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
-            float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
+            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
             optimizers.zero_grad()
             print(f"Peak Memory at iter: {iter_idx}")
             fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
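Taken together, the estimation.py changes only track the module and method renames. As a minimal sketch (not part of the diff), the updated call sequence looks like the following, assuming `job_config`, `parallel_dims`, and a `model` already exist as they do in this script:

    from torchtitan.float8 import Float8Handler  # previously torchtitan.float8_linear

    # a no-op handler if float8 is not enabled in job_config
    float8_handler = Float8Handler(job_config, parallel_dims)
    # swaps torch.nn.Linear modules for Float8Linear in place (no-op when disabled)
    float8_handler.convert_to_float8_training(model)

    # after each optimizer step, precompute dynamic scales for FSDP2 in a single all-reduce
    float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)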

torchtitan/float8_linear.py → torchtitan/float8.py (renamed; +5, -5)

@@ -21,7 +21,7 @@
 from torchtitan.parallelisms import ParallelDims
 
 
-def is_sm90_or_later():
+def _is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 

@@ -33,7 +33,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         float8_config = job_config.float8
         if not float8_config.enable_float8_linear:
             return
-        if not is_sm90_or_later():
+        if not _is_sm90_or_later():
             logger.warning(
                 "Failed to swap to Float8Linear because SM90 or later is not available",
             )

@@ -42,7 +42,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
         except ImportError as e:
             raise ImportError(
-                "torchao is not installed. Please install it to use fp8 linear layers."
+                "torchao is not installed. Please install it to use float8 linear layers."
             ) from e
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear

@@ -64,7 +64,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         self.enabled = True
 
-        # for precompute_fp8_dynamic_scale_for_fsdp
+        # for precompute_float8_dynamic_scale_for_fsdp
         self.precompute_scale = (
             enable_fsdp_float8_all_gather
             and float8_config.precompute_float8_dynamic_scale_for_fsdp

@@ -103,7 +103,7 @@ def convert_to_float8_training(self, model: nn.Module):
            f"{self.config.enable_fsdp_float8_all_gather}"
        )
 
-    def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module):
+    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
         if not self.enabled:
             return
 
torchtitan/metrics.py (+3, -3)

@@ -127,16 +127,16 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
 
 
 def build_metric_logger(
-    config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
+    job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
 ):
     """
     parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
     In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
     intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
     parallelism is enabled, without forcing logging from all ranks to capture loss information.
     """
-    dump_dir = config.job.dump_folder
-    tb_config = config.metrics
+    dump_dir = job_config.job.dump_folder
+    tb_config = job_config.metrics
     save_tb_folder = tb_config.save_tb_folder
     # since we don't have run id, use current minute as the identifier
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
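The only metrics.py change is the parameter rename, so callers passing the config by keyword need `job_config=` instead of `config=`. A minimal sketch, assuming a parsed `JobConfig` and a `ParallelDims` instance are already available:

    from torchtitan.metrics import build_metric_logger

    # keyword argument renamed from config= to job_config=
    metric_logger = build_metric_logger(
        job_config=job_config, parallel_dims=parallel_dims, tag="train"
    )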

torchtitan/parallelisms/__init__.py (+3, -64)

@@ -4,12 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass
-from functools import cached_property
 
-from torch.distributed.device_mesh import init_device_mesh
-from torchtitan.logging import logger
-from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama
+from torchtitan.parallelisms.parallel_dims import ParallelDims
+from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+from torchtitan.parallelisms.pipeline_llama import pipeline_llama
 from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
 
 

@@ -28,62 +26,3 @@
     "llama2": pipeline_llama,
     "llama3": pipeline_llama,
 }
-
-
-@dataclass
-class ParallelDims:
-    dp: int
-    tp: int
-    pp: int
-    world_size: int
-    enable_loss_parallel: bool
-    dp_type: str
-
-    def __post_init__(self):
-        self.dp_type = self.dp_type.lower()
-        self._validate()
-
-    def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
-        if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
-        assert dp >= 1, dp
-        assert tp >= 1, tp
-        assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
-        assert self.dp_type in ("fsdp", "ddp")
-
-    def build_mesh(self, device_type):
-        dims = []
-        names = []
-        for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
-        ):
-            if d > 1:
-                dims.append(d)
-                names.append(name)
-        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
-        names = tuple(names)
-        return init_device_mesh(device_type, dims, mesh_dim_names=names)
-
-    @property
-    def dp_enabled(self):
-        return self.dp > 1
-
-    @property
-    def tp_enabled(self):
-        return self.tp > 1
-
-    @property
-    def pp_enabled(self):
-        return self.pp > 1
-
-    @property
-    def loss_parallel_enabled(self):
-        return self.tp > 1 and self.enable_loss_parallel
-
-    @cached_property
-    def model_parallel_size(self):
-        return self.tp * self.pp
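Note that the split keeps the package's public surface intact: everything below is still importable from `torchtitan.parallelisms`, now re-exported from the new submodules. A minimal sketch based only on the imports shown above:

    from torchtitan.parallelisms import (
        ParallelDims,
        build_pipeline_schedule,
        parallelize_llama,
        pipeline_llama,
    )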

torchtitan/parallelisms/parallel_dims.py (new file, +70)

@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from functools import cached_property
+
+from torch.distributed.device_mesh import init_device_mesh
+from torchtitan.logging import logger
+
+
+@dataclass
+class ParallelDims:
+    dp: int
+    tp: int
+    pp: int
+    world_size: int
+    enable_loss_parallel: bool
+    dp_type: str
+
+    def __post_init__(self):
+        self.dp_type = self.dp_type.lower()
+        self._validate()
+
+    def _validate(self):
+        dp, tp, pp = self.dp, self.tp, self.pp
+        if dp == -1:
+            self.dp = dp = self.world_size // (tp * pp)
+        assert dp >= 1, dp
+        assert tp >= 1, tp
+        assert pp >= 1, pp
+        assert (
+            dp * tp * pp == self.world_size
+        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert self.dp_type in ("fsdp", "ddp")
+
+    def build_mesh(self, device_type):
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+        ):
+            if d > 1:
+                dims.append(d)
+                names.append(name)
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        names = tuple(names)
+        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+    @property
+    def dp_enabled(self):
+        return self.dp > 1
+
+    @property
+    def tp_enabled(self):
+        return self.tp > 1
+
+    @property
+    def pp_enabled(self):
+        return self.pp > 1
+
+    @property
+    def loss_parallel_enabled(self):
+        return self.tp > 1 and self.enable_loss_parallel
+
+    @cached_property
+    def model_parallel_size(self):
+        return self.tp * self.pp
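For illustration only (not part of the commit), here is a minimal sketch of how the moved `ParallelDims` resolves `dp=-1` and derives its properties, assuming torchtitan is installed; `build_mesh` is left out because it requires an initialized distributed environment:

    from torchtitan.parallelisms import ParallelDims  # still exported from the package

    # dp=-1 lets _validate() infer dp = world_size // (tp * pp) = 8 // (2 * 2) = 2
    dims = ParallelDims(
        dp=-1, tp=2, pp=2, world_size=8, enable_loss_parallel=True, dp_type="fsdp"
    )
    print(dims.dp)                     # 2
    print(dims.dp_enabled, dims.tp_enabled, dims.pp_enabled)  # True True True
    print(dims.loss_parallel_enabled)  # True: tp > 1 and loss parallel is requested
    print(dims.model_parallel_size)    # 4 = tp * pp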
