Add support for DDP and experimental CompiledAutograd #432

Merged · 9 commits · Jul 18, 2024
1 change: 1 addition & 0 deletions estimation.py
@@ -71,6 +71,7 @@ def estimate_memory(job_config: JobConfig):
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
9 changes: 9 additions & 0 deletions test_runner.py
@@ -304,6 +304,15 @@ def build_test_list():
],
"FSDP2 with float8 all-gather and precomputed dynamic scales",
"fsdp2_float8_all_gather_precompute_dynamic_scales",
),
OverrideDefinitions(
[
[
"--training.data_parallel_type ddp",
]
],
"DDP",
"ddp",
ngpu=4,
),
]
11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
@@ -312,6 +312,17 @@ def __init__(self):
The default value will be the number of pipeline stages, if unspecified.
""",
)
self.parser.add_argument(
"--training.data_parallel_type",
type=str,
default="fsdp",
help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
)
self.parser.add_argument(
"--experimental.enable_compiled_autograd",
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
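
Below is a hedged usage sketch for the two new flags. It assumes JobConfig.parse_args accepts an argv-style list and that a --training.compile flag already exists in torchtitan's config_manager; neither assumption is part of this diff.

# Sketch only: JobConfig.parse_args and --training.compile are assumptions
# based on torchtitan's existing config_manager, not changes in this PR.
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--training.data_parallel_type", "ddp",
        "--training.compile",
        "--experimental.enable_compiled_autograd",
    ]
)
assert config.training.data_parallel_type == "ddp"
assert config.experimental.enable_compiled_autograd is True
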
3 changes: 3 additions & 0 deletions torchtitan/parallelisms/__init__.py
@@ -28,8 +28,10 @@ class ParallelDims:
pp: int
world_size: int
enable_loss_parallel: bool
dp_type: str

def __post_init__(self):
self.dp_type = self.dp_type.lower()
self._validate()

def _validate(self):
@@ -42,6 +44,7 @@ def _validate(self):
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert self.dp_type in ("fsdp", "ddp")

def build_mesh(self, device_type):
dims = []
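
A hedged sketch of constructing ParallelDims directly with the new field. The dp and tp field names are assumed from the validation context above; only pp, world_size, enable_loss_parallel, and the new dp_type are visible in this hunk.

# Sketch only: dp and tp are assumed field names taken from the _validate
# context; dp_type is lower-cased in __post_init__ before _validate checks it.
from torchtitan.parallelisms import ParallelDims

parallel_dims = ParallelDims(
    dp=4,
    tp=1,
    pp=1,
    world_size=4,
    enable_loss_parallel=False,
    dp_type="DDP",
)
assert parallel_dims.dp_type == "ddp"  # anything other than "fsdp"/"ddp" fails the assert
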
36 changes: 33 additions & 3 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -16,6 +16,8 @@
from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -453,13 +455,15 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
return model


def apply_dp(
def apply_fsdp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""Apply data parallelism (FSDP2) to the model."""
"""
Apply data parallelism to the model. FSDP2 is used here.
"""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -492,6 +496,29 @@ def apply_dp(
return model


def apply_ddp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism.")

if job_config.training.compile:
if job_config.experimental.enable_compiled_autograd:
torch._dynamo.config.optimize_ddp = (
"python_reducer_without_compiled_forward"
)
else:
torch._dynamo.config.optimize_ddp = "ddp_optimizer"

model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)

logger.info("Applied DDP to the model")
return model


def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
@@ -516,6 +543,9 @@ def parallelize_llama(
model = apply_compile(model, job_config)

if parallel_dims.dp_enabled:
model = apply_dp(model, world_mesh, parallel_dims, job_config)
if parallel_dims.dp_type == "fsdp":
model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
else:
model = apply_ddp(model, world_mesh, parallel_dims, job_config)

return model
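
For reference, a hedged standalone sketch of the DDP branch added above: parallelize_llama dispatches on parallel_dims.dp_type, and the DDP path sets torch._dynamo.config.optimize_ddp before calling replicate(). The 4-GPU mesh and launch assumptions below are placeholders, not part of this PR.

# Sketch only: assumes a distributed launch (e.g. torchrun) with 4 GPUs; the
# replicate() arguments mirror apply_ddp above.
import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.device_mesh import init_device_mesh


def apply_ddp_sketch(
    model: nn.Module, compile_model: bool, enable_compiled_autograd: bool
) -> nn.Module:
    # The DDP path in this PR only accepts a 1-D (data-parallel-only) mesh.
    world_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("dp",))
    if compile_model:
        # With CompiledAutograd, use the Python reducer so the compiled backward
        # can capture gradient bucketing; otherwise fall back to ddp_optimizer.
        torch._dynamo.config.optimize_ddp = (
            "python_reducer_without_compiled_forward"
            if enable_compiled_autograd
            else "ddp_optimizer"
        )
    return replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
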
27 changes: 22 additions & 5 deletions train.py
@@ -138,6 +138,22 @@ def zero_grad(self):
return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(loss_parallel())
if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

yield

return context


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
@@ -160,6 +176,7 @@ def main(job_config: JobConfig):
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
@@ -194,9 +211,9 @@ def main(job_config: JobConfig):
dp_rank,
)

# loss_parallel enables dispatching to efficient loss operators
loss_parallel_ctx = (
loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
@@ -364,7 +381,7 @@ def loss_fn(pred, labels):
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with loss_parallel_ctx():
with train_context():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
@@ -381,7 +398,7 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
with loss_parallel_ctx():
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
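
Finally, a hedged minimal demo of the CompiledAutograd context that get_train_context enters: under torch._dynamo.utils.maybe_enable_compiled_autograd(True), the backward pass of a compiled model is itself captured and compiled. The toy model, shapes, and CPU device are placeholders.

# Sketch only: toy model and shapes are placeholders; the context manager is
# the same one entered by get_train_context above.
import torch
import torch.nn as nn

model = torch.compile(nn.Linear(16, 16))
x = torch.randn(8, 16)

with torch._dynamo.utils.maybe_enable_compiled_autograd(True):
    loss = model(x).sum()
    loss.backward()  # the autograd graph is traced and compiled, not just the forward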