# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from datetime import timedelta

import torch
from torch.distributed.elastic.multiprocessing.errors import record

from torchbenchmark.util.experiment.instantiator import (
    load_model,
    TorchBenchModelConfig,
)
from torchbenchmark.util.experiment.metrics import get_model_flops
from torchbenchmark.util.input import input_cast

from torchtitan import utils
from torchtitan.checkpoint import TrainState
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_gpu_memory_monitor
from torchtitan.parallelisms import ParallelDims
from torchtitan.parallelisms.parallelize_llama import torch_spmd_parallelize
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
    init_logger()
    logger.info(f"Starting job: {job_config.job.description}")

    # used for colorful printing
    color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor

    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
        dp_type=job_config.training.data_parallel_type,
    )
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)
    utils.init_distributed(job_config)
    # initialize GPU memory monitor and get peak flops for MFU calculation
    gpu_memory_monitor = build_gpu_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
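    # the data parallel mesh is used below for cross-rank loss reduction
    # and for reporting the global batch size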
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

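    # pipeline parallel mesh (pipeline stage splitting is not wired up in this script)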
    if parallel_dims.pp_enabled:
        pp_mesh = world_mesh["pp"]

    model_name = job_config.model.name

    # instantiate the model from torchbench
    config = TorchBenchModelConfig(
        name=model_name,
        test="train",
        device="cuda",
        batch_size=job_config.training.batch_size,
        extra_args=[],
    )
    model_flops = get_model_flops(config)
    benchmark_model = load_model(config)
    model, _ = benchmark_model.get_module()
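    # get_module() returns the wrapped nn.Module; the second return value is not used here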

    # TODO: there seems to be a bug with dtype conversion (e.g. use resnet50)
    # cast input dtype if needed
    param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
    input_cond = lambda x: x.dtype == torch.float32
    input_action = lambda x: x.to(param_dtype)
    if hasattr(benchmark_model, "example_inputs"):
        benchmark_model.example_inputs = input_cast(
            input_cond, input_action, benchmark_model.example_inputs
        )
    else:
        logger.warning(
            f"{model_name} example inputs haven't been cast to {param_dtype} yet!"
        )

    # log model size
    model_param_count = utils.get_num_params(model)
    logger.info(
        f"{color.blue}Model {model_name} "
        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
    )

    # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
    model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)

    # update model and optimizer after applying parallelisms
    benchmark_model.set_module(model)
    optimizer = benchmark_model.get_optimizer()
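    # register the parallelized module's parameters with the optimizer so that
    # the (possibly replaced) parameters are the ones being updated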
    optimizer.add_param_group({"params": model.parameters()})

    model.train()

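    # log peak GPU memory after model materialization and parallelization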
    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    logger.info(
        f"GPU memory usage for model: "
        f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
        f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
    )

    train_state = TrainState()

    # variables used to keep info for metrics logging
    losses_since_last_log = []
    gpu_memory_monitor.reset_peak_stats()

    # train loop
    logger.info(
        f"Training starts at step {train_state.step + 1}, "
        f"with local batch size {job_config.training.batch_size}, "
        f"global batch size {job_config.training.batch_size * dp_degree}, "
        f"total steps {job_config.training.steps}"
    )
    with maybe_enable_profiling(
        job_config, global_step=train_state.step
    ) as torch_profiler, maybe_enable_memory_snapshot(
        job_config, global_step=train_state.step
    ) as memory_profiler:
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
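            # CUDA events capture on-device execution time; the time_ns() pair below
            # captures CPU wall time for the same step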

            # Use time_ns() instead of time(), which does not guarantee better precision
            # than 1 second according to https://docs.python.org/3/library/time.html#time.time.
            t0 = time.time_ns()
            start_event.record()

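            # some torchbench models expose separate forward/backward/optimizer_step
            # stages; use them when a monolithic train() is not provided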
            is_staged = (
                hasattr(benchmark_model, "forward")
                and hasattr(benchmark_model, "backward")
                and hasattr(benchmark_model, "optimizer_step")
            )
            if is_staged and (getattr(benchmark_model, "train", None) is None):
                if optimizer is not None:
                    optimizer.zero_grad()
                loss = benchmark_model.forward()
                benchmark_model.backward(loss)
                if optimizer is not None:
                    benchmark_model.optimizer_step()
            else:
                loss = benchmark_model.train()

            end_event.record()
            torch.cuda.synchronize()
            t1 = time.time_ns()
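            # (GPU time in ms from CUDA events, CPU wall time in ms converted from ns)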
            time_delta = start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000

            # log metrics
            losses_since_last_log.append(loss)
            if (
                train_state.step == 1
                or train_state.step % job_config.metrics.log_freq == 0
            ):
                losses = [
                    loss.item() if isinstance(loss, torch.Tensor) else loss
                    for loss in losses_since_last_log
                ]
                avg_loss, max_loss = sum(losses) / len(losses), max(losses)
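                # reduce losses across the data parallel group so all ranks log
                # the same global values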
                if parallel_dims.dp_enabled:
                    global_avg_loss, global_max_loss = (
                        utils.dist_mean(avg_loss, dp_mesh),
                        utils.dist_max(max_loss, dp_mesh),
                    )
                else:
                    global_avg_loss, global_max_loss = avg_loss, max_loss

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                logger.info(
                    f"{color.cyan}step: {train_state.step:2} "
                    f"{color.green}loss: {global_avg_loss:7.4f} "
                    f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
                    f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
                    f"{color.blue}GPU time: {time_delta[0]:.3f}ms "
                    f"CPU wall time: {time_delta[1]:.3f}ms{color.reset}"
                )

                losses_since_last_log.clear()
                gpu_memory_monitor.reset_peak_stats()

            # signal the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # reduce timeout after first train step for faster signal
            # (assuming lazy init and compilation are finished)
            if train_state.step == 1:
                utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    logger.info("Training completed")


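# Example launch with torchrun (a sketch; exact flag names depend on the JobConfig
# definition, which typically exposes options as --<section>.<option>):
#   torchrun --nproc_per_node=8 <this_script>.py \
#       --model.name <torchbench_model> --training.batch_size 8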
if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()