From 8a13dd2bb379b8f62e13fb4aebb7b21e4e809322 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 18 Jul 2024 09:55:38 -0700
Subject: [PATCH] Add support of DDP and experimental CompiledAutograd

Summary:
Address the comments in https://github.com/pytorch/torchtitan/pull/319 and
resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600 --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/432
---
 estimation.py                                |  1 +
 test_runner.py                               |  9 +++++
 torchtitan/config_manager.py                 | 11 ++++++
 torchtitan/parallelisms/__init__.py          |  3 ++
 torchtitan/parallelisms/parallelize_llama.py | 36 ++++++++++++++++++--
 train.py                                     | 27 ++++++++++++---
 6 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/estimation.py b/estimation.py
index e652c5818..3e393399b 100644
--- a/estimation.py
+++ b/estimation.py
@@ -71,6 +71,7 @@ def estimate_memory(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
     )
 
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
diff --git a/test_runner.py b/test_runner.py
index c84ca6af5..6a7b6b1a5 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -304,6 +304,15 @@ def build_test_list():
             ],
             "FSDP2 with float8 all-gather and precomputed dynamic scales",
             "fsdp2_float8_all_gather_precompute_dynamic_scales",
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_type ddp",
+                ]
+            ],
+            "DDP",
+            "ddp",
             ngpu=4,
         ),
     ]
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2bd6e3705..9a0868306 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -312,6 +312,17 @@ def __init__(self):
                 The default value will be the number of pipeline stages, if unspecified.
             """,
         )
+        self.parser.add_argument(
+            "--training.data_parallel_type",
+            type=str,
+            default="fsdp",
+            help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
+        )
+        self.parser.add_argument(
+            "--experimental.enable_compiled_autograd",
+            action="store_true",
+            help="Enable CompiledAutograd to compile the backward.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index 7e1b21c79..2fdba316f 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -28,8 +28,10 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    dp_type: str
 
     def __post_init__(self):
+        self.dp_type = self.dp_type.lower()
         self._validate()
 
     def _validate(self):
@@ -42,6 +44,7 @@ def _validate(self):
         assert (
             dp * tp * pp == self.world_size
         ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert self.dp_type in ("fsdp", "ddp")
 
     def build_mesh(self, device_type):
         dims = []
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index ec0f67637..33b9d6d35 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -16,6 +16,8 @@
 from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+
+from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
@@ -453,13 +455,15 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
     return model
 
 
-def apply_dp(
+def apply_fsdp(
     model: nn.Module,
     world_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
 ):
-    """Apply data parallelism (FSDP2) to the model."""
+    """
+    Apply data parallelism to the model. FSDP2 is used here.
+    """
     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
@@ -492,6 +496,29 @@ def apply_dp(
     return model
 
 
+def apply_ddp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    if world_mesh.ndim > 1:
+        raise RuntimeError("DDP has not supported > 1D parallelism.")
+
+    if job_config.training.compile:
+        if job_config.experimental.enable_compiled_autograd:
+            torch._dynamo.config.optimize_ddp = (
+                "python_reducer_without_compiled_forward"
+            )
+        else:
+            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+
+    model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
+
+    logger.info("Applied DDP to the model")
+    return model
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
@@ -516,6 +543,9 @@ def parallelize_llama(
         model = apply_compile(model, job_config)
 
     if parallel_dims.dp_enabled:
-        model = apply_dp(model, world_mesh, parallel_dims, job_config)
+        if parallel_dims.dp_type == "fsdp":
+            model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
+        else:
+            model = apply_ddp(model, world_mesh, parallel_dims, job_config)
 
     return model
diff --git a/train.py b/train.py
index afd1d8887..b7eee302f 100644
--- a/train.py
+++ b/train.py
@@ -138,6 +138,22 @@ def zero_grad(self):
     return OptimizersContainer([_build_optimizer(model) for model in model_parts])
 
 
+def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+    @contextlib.contextmanager
+    def context():
+        with contextlib.ExitStack() as stack:
+            if enable_loss_parallel:
+                stack.enter_context(loss_parallel())
+            if enable_compiled_autograd:
+                stack.enter_context(
+                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
+                )
+
+            yield
+
+    return context
+
+
 # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
@@ -160,6 +176,7 @@ def main(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
     )
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)
@@ -194,9 +211,9 @@ def main(job_config: JobConfig):
         dp_rank,
     )
 
-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
     )
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -364,7 +381,7 @@ def loss_fn(pred, labels):
             # pipeline parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-            with loss_parallel_ctx():
+            with train_context():
                 if pp_mesh.get_local_rank() == 0:
                     pp_schedule.step(input_ids)
                 elif is_last_stage:
@@ -381,7 +398,7 @@ def loss_fn(pred, labels):
                 )
         else:
            # Non-PP forward / backward
-            with loss_parallel_ctx():
+            with train_context():
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
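Note for reviewers, outside the patch itself: the `get_train_context` helper added to train.py composes optional behaviors with `contextlib.ExitStack`, so the training loop enters a single context regardless of which features are enabled. Below is a minimal, standalone sketch of that pattern under stated assumptions; the dummy context managers only stand in for the real `loss_parallel()` and `torch._dynamo.utils.maybe_enable_compiled_autograd(True)` calls and are not part of this change.

```
# Illustration only -- not part of the patch. The dummy context managers stand in
# for the real loss_parallel() and maybe_enable_compiled_autograd(True) contexts.
import contextlib


@contextlib.contextmanager
def dummy_loss_parallel():
    print("loss parallel: enabled")
    yield
    print("loss parallel: disabled")


@contextlib.contextmanager
def dummy_compiled_autograd(enabled: bool):
    print(f"compiled autograd: {enabled}")
    yield


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
    # Same shape as the helper in train.py: return a reusable context manager
    # that stacks only the features that are switched on.
    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(dummy_loss_parallel())
            if enable_compiled_autograd:
                stack.enter_context(dummy_compiled_autograd(True))
            yield

    return context


if __name__ == "__main__":
    train_context = get_train_context(True, False)
    for _ in range(2):  # the stack is re-entered on every training step
        with train_context():
            pass  # forward / backward would run here
```

The design point is that `train_context()` can be re-entered on every step and only stacks what the job config turned on, instead of special-casing each feature combination with nested `with` blocks.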