[BE][4/n] split pipeline_llama into a separate file
ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: pytorch#499
Showing 9 changed files with 474 additions and 462 deletions.
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh
from torchtitan.logging import logger


@dataclass
class ParallelDims:
    dp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool
    dp_type: str

    def __post_init__(self):
        self.dp_type = self.dp_type.lower()
        self._validate()

    def _validate(self):
        dp, tp, pp = self.dp, self.tp, self.pp
        # dp == -1 means: infer the data-parallel degree from the ranks
        # left over after tensor and pipeline parallelism are accounted for.
        if dp == -1:
            self.dp = dp = self.world_size // (tp * pp)
        assert dp >= 1, dp
        assert tp >= 1, tp
        assert pp >= 1, pp
        assert (
            dp * tp * pp == self.world_size
        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
        assert self.dp_type in ("fsdp", "ddp")

    def build_mesh(self, device_type):
        # Only materialize mesh dimensions whose degree is > 1,
        # ordered pp -> dp -> tp.
        dims = []
        names = []
        for d, name in zip(
            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
        ):
            if d > 1:
                dims.append(d)
                names.append(name)
        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        names = tuple(names)
        return init_device_mesh(device_type, dims, mesh_dim_names=names)

    @property
    def dp_enabled(self):
        return self.dp > 1

    @property
    def tp_enabled(self):
        return self.tp > 1

    @property
    def pp_enabled(self):
        return self.pp > 1

    @property
    def loss_parallel_enabled(self):
        return self.tp > 1 and self.enable_loss_parallel

    @cached_property
    def model_parallel_size(self):
        return self.tp * self.pp
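
For context, a minimal usage sketch of the class above. The degrees, world size, and device type are illustrative values, not taken from this diff, and the snippet assumes it runs inside an initialized torch.distributed job (e.g., launched via torchrun):

# Illustrative values; assumes an 8-rank distributed job.
# dp=-1 lets _validate() infer the data-parallel degree:
# 8 // (tp=2 * pp=2) -> dp=2, and 2 * 2 * 2 == world_size.
parallel_dims = ParallelDims(
    dp=-1,
    tp=2,
    pp=2,
    world_size=8,
    enable_loss_parallel=True,
    dp_type="fsdp",
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")

# Submeshes are addressed by the names registered in build_mesh;
# a dimension only exists in the mesh if its degree is > 1.
if parallel_dims.dp_enabled:
    dp_mesh = world_mesh["dp"]
if parallel_dims.pp_enabled:
    pp_mesh = world_mesh["pp"]

Because build_mesh skips degree-1 dimensions, callers should gate submesh lookups on the *_enabled properties rather than indexing unconditionally.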