diff --git a/.github/workflows/integration_test_4gpu.yaml b/.github/workflows/integration_test_4gpu.yaml index 3816f404..813e11af 100644 --- a/.github/workflows/integration_test_4gpu.yaml +++ b/.github/workflows/integration_test_4gpu.yaml @@ -39,5 +39,6 @@ jobs: python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/ + USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git mkdir artifacts-to-be-uploaded python ./test_runner.py artifacts-to-be-uploaded --ngpu 4 diff --git a/README.md b/README.md index 18364d8f..56785112 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,14 @@ Our guiding principles when building `torchtitan`: [![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!") +### Dive into the code + +You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first: +* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code +* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model +* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints +* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants) + ## Pre-Release Updates: #### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development. Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly). @@ -66,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th ```bash # Get your HF token from https://huggingface.co/settings/tokens -# llama3 tokenizer.model +# llama3 or 3.1 tokenizer.model python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=... 
# llama2 tokenizer.model diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 1abc77ec..77185bfc 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -18,8 +18,6 @@ set -ex -export USE_LIBUV=1 -TRAINER_DIR=${1:-/home/$USER/local/torchtitan} NGPU=1 LOG_RANK=0 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} diff --git a/estimation.py b/estimation.py index ddf24d8a..acf867d5 100644 --- a/estimation.py +++ b/estimation.py @@ -9,22 +9,19 @@ import os import torch -import torch.nn.functional as F from torch._guards import active_fake_mode from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed import destroy_process_group from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker -from torch.distributed.tensor.parallel import loss_parallel from torch.testing._internal.distributed.fake_pg import FakeStore from torchtitan.config_manager import JobConfig -from torchtitan.datasets import create_tokenizer -from torchtitan.float8_linear import build_fp8_linear -from torchtitan.logging_utils import init_logger, logger -from torchtitan.lr_scheduling import get_lr_schedulers +from torchtitan.datasets import build_tokenizer +from torchtitan.float8_linear import Float8Handler +from torchtitan.logging import init_logger, logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import models_parallelize_fns, ParallelDims -from train import build_optimizers +from train import get_train_context def estimate_memory(job_config: JobConfig): @@ -61,9 +58,10 @@ def estimate_memory(job_config: JobConfig): logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.") job_config.model.norm_type = "rmsnorm" - if job_config.training.compile: + if job_config.training.compile or job_config.experimental.enable_compiled_autograd: logger.info("Compile mode is not supported yet. 
Switching to eager mode.") job_config.training.compile = False + job_config.experimental.enable_compiled_autograd = False parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, @@ -71,6 +69,7 @@ def estimate_memory(job_config: JobConfig): pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, + dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") @@ -93,16 +92,18 @@ def estimate_memory(job_config: JobConfig): # build tokenizer tokenizer_type = model_name_to_tokenizer[model_name] - tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path) + tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) - # loss_parallel enables dispatching to efficient loss operators - loss_parallel_ctx = ( - loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext + train_context = get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, ) # loss fn can be shared by pipeline-parallel or non-pp execution def loss_fn(pred, labels): - return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) + return torch.nn.functional.cross_entropy( + pred.flatten(0, 1), labels.flatten(0, 1) + ) # build model (using meta init) model_cls = model_name_to_cls[model_name] @@ -123,9 +124,10 @@ def loss_fn(pred, labels): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) - # apply fp8 linear module swap - if job_config.training.fp8_linear: - build_fp8_linear(whole_model, job_config) + # a no-op hander if fp8 is not enabled + float8_handler = Float8Handler(job_config, parallel_dims) + # swap to Float8Linear base on fp8 config + float8_handler.convert_to_float8_training(whole_model) # apply PT-D DP/TP parallelisms and activation checkpointing model_parts = [whole_model] @@ -143,7 +145,7 @@ def loss_fn(pred, labels): # build optimizer after applying parallelisms to the model optimizers = build_optimizers(model_parts, job_config) - lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config) + lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) for model in model_parts: model.train() @@ -170,7 +172,7 @@ def loss_fn(pred, labels): for iter_idx in range(2): input_ids, labels = batch # train step - with loss_parallel_ctx(): + with train_context(): pred = whole_model(input_ids) loss = loss_fn(pred, labels) del pred @@ -181,9 +183,14 @@ def loss_fn(pred, labels): torch.nn.utils.clip_grad_norm_( model.parameters(), job_config.training.max_norm, foreach=True ) + # sync float8 amaxes and scales + float8_handler.sync_float8_amax_and_scale_history(model) # optimizer step optimizers.step() lr_schedulers.step() + # calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # it issues a single all-reduce for all parameters at once for better performance + float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) optimizers.zero_grad() print(f"Peak Memory at iter: {iter_idx}") fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True) @@ -217,4 +224,4 @@ def loss_fn(pred, labels): try: estimate_memory(config) finally: - destroy_process_group() + torch.distributed.destroy_process_group() diff --git a/multinode_trainer.slurm b/multinode_trainer.slurm index 09b94ef1..4bc495d3 100644 --- a/multinode_trainer.slurm +++ b/multinode_trainer.slurm @@ -53,7 +53,6 @@ export 
NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" export NCCL_BUFFSIZE=2097152 #export TORCH_DIST_INIT_BARRIER=1 export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 -#export USE_LIBUV=1 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama2_13b.toml"} dcgmi profile --pause diff --git a/run_llama_train.sh b/run_llama_train.sh index cf4943a6..a4107806 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -7,22 +7,11 @@ set -ex -# libUV is a scalable backend for TCPStore which is used in processGroup -# rendezvous. This is the recommended backend for distributed training. -export USE_LIBUV=1 -TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan} - # use envs as local overrides for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh - NGPU=${NGPU:-"8"} -NNODES=${NNODES:-"1"} - -# by default log just rank 0 output, LOG_RANK=${LOG_RANK:-0} - - CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} overrides="" @@ -30,16 +19,6 @@ if [ $# -ne 0 ]; then overrides="$*" fi -# Check if --estimate.memory=True is in the arguments -if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then - # Calculate WORLD_SIZE as the product of NGPU and NNODES - # Export WORLD_SIZE and LOCAL_RANK - export WORLD_SIZE=$((NGPU * NNODES)) - export LOCAL_RANK=0 - python estimation.py --job.config_file ${CONFIG_FILE} $overrides -else - # Call train.py if not in estimation mode - torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ - --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ - train.py --job.config_file ${CONFIG_FILE} $overrides -fi +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/run_memory_estimation.sh b/run_memory_estimation.sh new file mode 100755 index 00000000..02148b84 --- /dev/null +++ b/run_memory_estimation.sh @@ -0,0 +1,26 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. 
+# NGPU=4 ./run_memory_estimation.sh +NGPU=${NGPU:-"8"} +NNODES=${NNODES:-"1"} +CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +# Calculate WORLD_SIZE as the product of NGPU and NNODES +# Export WORLD_SIZE and LOCAL_RANK +export WORLD_SIZE=$((NGPU * NNODES)) +export LOCAL_RANK=0 +python estimation.py --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides diff --git a/test/datasets/test_checkpoint.py b/test/datasets/test_checkpoint.py index 6f04dd23..741c997f 100644 --- a/test/datasets/test_checkpoint.py +++ b/test/datasets/test_checkpoint.py @@ -6,7 +6,7 @@ import torch from torchtitan.datasets.hf_datasets import build_hf_data_loader -from torchtitan.datasets.tokenizer import create_tokenizer +from torchtitan.datasets.tokenizer import build_tokenizer class TestCheckpoint: @@ -42,7 +42,7 @@ def _build_dataloader( self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank ): tokenizer_type = "tiktoken" - tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") + tokenizer = build_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") return build_hf_data_loader( dataset_name=dataset_name, dataset_path=dataset_path, diff --git a/test_runner.py b/test_runner.py index 82492bbd..a7c95ce1 100755 --- a/test_runner.py +++ b/test_runner.py @@ -46,6 +46,21 @@ def build_test_list(): """ integration_tests_flavors = defaultdict(list) integration_tests_flavors["debug_model.toml"] = [ + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--experimental.pipeline_parallel_degree 4", + "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7", + "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b", + "--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp + ], + ], + "PP looped flexible 1f1b test", + "pp_looped_flexible_1f1b", + requires_seed_checkpoint=True, + ngpu=4, + ), OverrideDefinitions( [ [ @@ -284,6 +299,16 @@ def build_test_list(): "fsdp2_mem_tracker", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_type ddp", + ] + ], + "DDP", + "ddp", + ngpu=4, + ), ] return integration_tests_flavors @@ -315,6 +340,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): for override_arg in test_flavor.override_args: cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh" + if test_name == "fsdp2_mem_tracker": + cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh" cmd += " " + dump_folder_arg cmd += " " + model_flavor_arg if override_arg: diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 30317e3c..b71419c6 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -10,6 +10,8 @@ import re import shutil import time +from dataclasses import dataclass, field +from io import BytesIO from multiprocessing import get_context from typing import Any, Dict, List, Union @@ -27,7 +29,7 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import DataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP -from torchtitan.logging_utils import init_logger, logger +from torchtitan.logging import init_logger, logger class IntervalType(enum.Enum): @@ -41,6 +43,43 @@ class AsyncMode(str, enum.Enum): ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" +@dataclass +class 
TrainState(Stateful): + step: int = 0 + global_avg_losses: List[float] = field(default_factory=list) + global_max_losses: List[float] = field(default_factory=list) + log_steps: List[int] = field(default_factory=list) + + def state_dict(self) -> Dict[str, Any]: + # Only checkpoint global_avg_losses and global_max_losses per log frequency + # to avoid sync overhead in every iteration. + global_avg_losses_bytes = BytesIO() + torch.save(self.global_avg_losses, global_avg_losses_bytes) + global_max_losses_bytes = BytesIO() + torch.save(self.global_max_losses, global_max_losses_bytes) + log_steps_bytes = BytesIO() + torch.save(self.log_steps, log_steps_bytes) + return { + "step": torch.tensor(self.step, dtype=torch.int32), + "global_avg_losses": global_avg_losses_bytes, + "global_max_losses": global_max_losses_bytes, + "log_steps": log_steps_bytes, + } + + def load_state_dict(self, state_dict) -> None: + self.step = state_dict["step"].item() + state_dict["global_avg_losses"].seek(0) + self.global_avg_losses = torch.load( + state_dict["global_avg_losses"], weights_only=False + ) + state_dict["global_max_losses"].seek(0) + self.global_max_losses = torch.load( + state_dict["global_max_losses"], weights_only=False + ) + state_dict["log_steps"].seek(0) + self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) + + class ModelWrapper(Stateful): def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None: self.model = [model] if isinstance(model, nn.Module) else model @@ -124,10 +163,10 @@ def checkpoint_mp(recv, send): class CheckpointManager: def __init__( self, + dataloader: DataLoader, model_parts: List[nn.Module], optimizers: List[torch.optim.Optimizer], lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler], - dataloader: DataLoader, states: Dict[str, Any], job_config: JobConfig, ) -> None: @@ -390,7 +429,7 @@ def save(self, curr_step: int, force: bool = False) -> None: f"in {time.monotonic() - begin:.2f} seconds." ) - def wait_for_staging(self) -> None: + def maybe_wait_for_staging(self) -> None: if ( self.enable_checkpoint and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3ade1b9d..2bc37bfb 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -16,7 +16,7 @@ except ModuleNotFoundError: import tomli as tomllib -from torchtitan.logging_utils import logger +from torchtitan.logging import logger TORCH_DTYPE_MAP = { "float16": torch.float16, @@ -275,7 +275,7 @@ def __init__(self): self.parser.add_argument( "--experimental.pipeline_parallel_schedule", type=str, - choices=["1f1b", "gpipe", "interleaved_1f1b"], + choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"], default="1f1b", help=""" Specify the Pipeline Parallel schedule to use. @@ -312,6 +312,17 @@ def __init__(self): The default value will be the number of pipeline stages, if unspecified. """, ) + self.parser.add_argument( + "--training.data_parallel_type", + type=str, + default="fsdp", + help="Data parallelism type. 
TorchTitan currently supports FSDP and DDP.", + ) + self.parser.add_argument( + "--experimental.enable_compiled_autograd", + action="store_true", + help="Enable CompiledAutograd to compile the backward.", + ) self.parser.add_argument( "--training.mixed_precision_param", type=str, @@ -337,16 +348,6 @@ def __init__(self): action="store_true", help="Whether to compile the model", ) - self.parser.add_argument( - "--training.fp8_linear", - action="store_true", - help=""" - If true, swaps `torch.nn.Linear` with `Float8Linear` with - default settings (dynamic scaling). - This feature requires you to install 'float8_experimental' which can be found - here: https://github.com/pytorch-labs/float8_experimental - """, - ) self.parser.add_argument( "--training.gc_freq", type=int, @@ -442,6 +443,7 @@ def __init__(self): 0 is the default value. """, ) + # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", @@ -459,6 +461,48 @@ def __init__(self): """, ) + # float8 configs + self.parser.add_argument( + "--float8.enable_float8_linear", + action="store_true", + help=""" + If true, swaps `torch.nn.Linear` with `Float8Linear`. + This feature requires you to install 'torchao' which can be found + here: https://github.com/pytorch/ao + """, + ) + self.parser.add_argument( + "--float8.enable_fsdp_float8_all_gather", + action="store_true", + default=False, + help="Whether enable float8 all-gather in FSDP", + ) + self.parser.add_argument( + "--float8.precompute_float8_dynamic_scale_for_fsdp", + action="store_true", + default=False, + help="Whether precompute float8 scales dynamically for FSDP", + ) + self.parser.add_argument( + "--float8.scaling_type_input", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + choices=["dynamic", "delayed"], + ) + self.parser.add_argument( + "--float8.scaling_type_weight", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) + self.parser.add_argument( + "--float8.scaling_type_grad_output", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) + # communications library settings self.parser.add_argument( "--comm.init_timeout_seconds", diff --git a/torchtitan/datasets/__init__.py b/torchtitan/datasets/__init__.py index e9a149c6..75ea6b66 100644 --- a/torchtitan/datasets/__init__.py +++ b/torchtitan/datasets/__init__.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. 
from torchtitan.datasets.hf_datasets import build_hf_data_loader -from torchtitan.datasets.tokenizer import create_tokenizer +from torchtitan.datasets.tokenizer import build_tokenizer __all__ = [ "build_hf_data_loader", - "create_tokenizer", + "build_tokenizer", ] diff --git a/torchtitan/datasets/download_tokenizer.py b/torchtitan/datasets/download_tokenizer.py index 44ef5f59..a419d709 100644 --- a/torchtitan/datasets/download_tokenizer.py +++ b/torchtitan/datasets/download_tokenizer.py @@ -20,8 +20,8 @@ def hf_download( try: hf_hub_download( - repo_id, - tokenizer_path, + repo_id=repo_id, + filename=tokenizer_path, local_dir=local_dir, local_dir_use_symlinks=False, token=hf_token, diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index d8cd5d83..0b894e24 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -20,7 +20,7 @@ ) from e from torchtitan.datasets.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from datasets import load_dataset from datasets.distributed import split_dataset_by_node diff --git a/torchtitan/datasets/tokenizer/__init__.py b/torchtitan/datasets/tokenizer/__init__.py index 346caf83..7ff74722 100644 --- a/torchtitan/datasets/tokenizer/__init__.py +++ b/torchtitan/datasets/tokenizer/__init__.py @@ -8,10 +8,10 @@ from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger -def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer: +def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer: logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}") if tokenizer_type == "sentencepiece": return SentencePieceTokenizer(tokenizer_path) diff --git a/torchtitan/datasets/tokenizer/sentencepiece.py b/torchtitan/datasets/tokenizer/sentencepiece.py index 7229daa3..c71afddd 100644 --- a/torchtitan/datasets/tokenizer/sentencepiece.py +++ b/torchtitan/datasets/tokenizer/sentencepiece.py @@ -11,7 +11,7 @@ from sentencepiece import SentencePieceProcessor from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger class SentencePieceTokenizer(Tokenizer): diff --git a/torchtitan/datasets/tokenizer/tiktoken.py b/torchtitan/datasets/tokenizer/tiktoken.py index 1ec5de20..c879e7f3 100644 --- a/torchtitan/datasets/tokenizer/tiktoken.py +++ b/torchtitan/datasets/tokenizer/tiktoken.py @@ -26,7 +26,7 @@ from tiktoken.load import load_tiktoken_bpe from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger class TikTokenizer(Tokenizer): diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 0bd0900c..494b6046 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -4,39 +4,136 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# [Note] Getting the 'float8_experimental' package: -# This script requires the 'float8_experimental' package to function correctly. +# [Note] Getting the 'torchao' package: +# This script requires the 'torchao' package to function correctly. # Please ensure you have this package installed from the appropriate repository. 
-# You can obtain it from https://github.com/pytorch-labs/float8_experimental. -# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git` +# You can obtain it from https://github.com/pytorch/ao by following the +# installation instructions. # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance +import torch import torch.nn as nn from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger +from torchtitan.logging import logger +from torchtitan.parallelisms import ParallelDims -def build_fp8_linear(model: nn.Module, job_config: JobConfig): - """ - This function converts the linear layers to `Float8Linear`. Note that today, - only dynamic tensor scaling (the default) is supported. +def is_sm90_or_later(): + # Float8 is only supported on H100+ GPUs + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +class Float8Handler: + def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): + self.enabled = False + + float8_config = job_config.float8 + if not float8_config.enable_float8_linear: + return + if not is_sm90_or_later(): + logger.warning( + "Failed to swap to Float8Linear because SM90 or later is not available", + ) + return + try: + from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType + except ImportError as e: + raise ImportError( + "torchao is not installed. Please install it to use fp8 linear layers." + ) from e - This will mutate the model inplace. - """ - use_fp8_linear = job_config.training.fp8_linear - try: - from float8_experimental.float8_linear import Float8Linear - from float8_experimental.float8_linear_utils import ( - swap_linear_with_float8_linear, - ) - except ImportError as exc: - raise ImportError( - "float8_experimental is not installed. Please install it to use fp8 linear layers." 
- ) from exc - if use_fp8_linear: # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear - swap_linear_with_float8_linear(model, Float8Linear) - logger.info("Swapped to Float8Linear layers") + enable_fsdp_float8_all_gather = ( + parallel_dims.dp_enabled + and parallel_dims.dp_type == "fsdp" + and float8_config.enable_fsdp_float8_all_gather + ) + scaling_type_input = ScalingType(float8_config.scaling_type_input) + scaling_type_weight = ScalingType(float8_config.scaling_type_weight) + scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output) + self.config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + enable_pre_and_post_forward=False, + ) + + self.enabled = True + + # for precompute_fp8_dynamic_scale_for_fsdp + self.precompute_scale = ( + enable_fsdp_float8_all_gather + and float8_config.precompute_float8_dynamic_scale_for_fsdp + ) + + # for sync_float8_amax_and_scale_history + self.delayed_scaling = ( + scaling_type_input == "delayed" + or scaling_type_weight == "delayed" + or scaling_type_grad_output == "delayed" + ) + self._sync_float8_amax_and_scale_history = None + self.compile = job_config.training.compile + + logger.info("Float8 training active") + + def convert_to_float8_training(self, model: nn.Module): + """ + This function converts the linear layers of `model` to `Float8Linear`. + Note that today, only dynamic tensor scaling (the default) is supported. + This will mutate the model inplace. + """ + if not self.enabled: + return + + from torchao.float8 import convert_to_float8_training + + # Mutates the model inplace replacing instances of nn.Linear with Float8Linear + convert_to_float8_training( + model, + config=self.config, + module_filter_fn=lambda mod, fqn: fqn != "output", + ) + logger.info( + "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" + f"{self.config.enable_fsdp_float8_all_gather}" + ) + + def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module): + if not self.enabled: + return + + if not self.precompute_scale: + return + + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp + + precompute_float8_dynamic_scale_for_fsdp(model) + + def sync_float8_amax_and_scale_history(self, model: nn.Module): + if not self.enabled: + return + + if not self.delayed_scaling: + return + + from torchao.float8 import sync_float8_amax_and_scale_history + + # TODO(vkuzo): see if precalculating the modules to sync over is going to + # meaningfully help performance + + if self._sync_float8_amax_and_scale_history is None: + if self.compile: + self._sync_float8_amax_and_scale_history = torch.compile( + sync_float8_amax_and_scale_history + ) + else: + self._sync_float8_amax_and_scale_history = ( + sync_float8_amax_and_scale_history + ) + + self._sync_float8_amax_and_scale_history(model) diff --git a/torchtitan/logging_utils.py b/torchtitan/logging.py similarity index 100% rename from torchtitan/logging_utils.py rename to torchtitan/logging.py diff --git a/torchtitan/lr_scheduling.py b/torchtitan/lr_scheduling.py deleted file mode 100644 index 35f39e13..00000000 --- a/torchtitan/lr_scheduling.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torch.optim.lr_scheduler import LambdaLR -from torchtitan.config_manager import JobConfig - -# global states for scheduling -# these are needed as LambdaLR does not support argument passing -_warmup_steps = 200 -_decay_steps = 0 - - -def linear_warmup_linear_decay(current_step: int) -> float: - """Computes linear warmup followed by linear decay. - Per LambdaLR requirement, this is accomplished by returning - a multiplicative factor to adjust the learning rate to - create the desired schedule. - """ - if current_step < _warmup_steps: - # linear warmup - # 0-indexed step, hence + 1 adjustments - current_step += 1 - curr_adjustment = float(current_step / (_warmup_steps + 1)) - - else: - # linear decay - normalized_step = _decay_steps - (current_step - _warmup_steps) - curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps - - return curr_adjustment - - -def get_lr_schedulers(optimizers, job_config: JobConfig): - def _get_lr_scheduler(optimizer): - """Build a linear warmup and linear decay scheduler""" - global _warmup_steps, _decay_steps - _warmup_steps = int(job_config.training.warmup_steps) - _decay_steps = float(max(1, job_config.training.steps - _warmup_steps)) - - warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay) - return warmup_scheduler - - class SchedulersContainer: - """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages""" - - def __init__(self, schedulers): - self.schedulers = schedulers - - def step(self): - for schedulers in self.schedulers: - schedulers.step() - - return SchedulersContainer( - [_get_lr_scheduler(optimizer) for optimizer in optimizers] - ) diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py index 1717439b..f86ccc98 100644 --- a/torchtitan/metrics.py +++ b/torchtitan/metrics.py @@ -12,7 +12,8 @@ import torch from torch.utils.tensorboard import SummaryWriter from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger +from torchtitan.logging import logger +from torchtitan.parallelisms import ParallelDims # named tuple for passing GPU memory stats for logging GPUMemStats = namedtuple( @@ -110,16 +111,29 @@ def close(self): self.writer.close() +def _get_metrics_rank(parallel_dims: ParallelDims) -> int: + """ + Returns global rank 0 in non-pipeline-parallel configs, and returns the global + rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. + """ + if parallel_dims.pp_enabled: + world_size = parallel_dims.world_size + pp_size = parallel_dims.pp + metrics_log_rank = (world_size // pp_size) * (pp_size - 1) + else: + metrics_log_rank = 0 + + return metrics_log_rank + + def build_metric_logger( - config: JobConfig, metrics_log_rank: int = 0, tag: Optional[str] = None + config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None ): """ - metrics_log_rank controls which rank acts as 'rank 0' for logging metrics. - - If 'tb_config.rank_0_only' is set, then `metrics_log_rank` will be used as the rank to log metrics. - This is intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline - parallelism is enabled, without forcing logging from all ranks to capture loss information when using pipeline - parallelism. + parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'. 
+ In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is + intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline + parallelism is enabled, without forcing logging from all ranks to capture loss information. """ dump_dir = config.job.dump_folder tb_config = config.metrics @@ -134,7 +148,7 @@ def build_metric_logger( f"Metrics logging active. Tensorboard logs will be saved at {log_dir}" ) if tb_config.rank_0_only: - enable_tb = torch.distributed.get_rank() == metrics_log_rank + enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims) else: rank_str = f"rank_{torch.distributed.get_rank()}" log_dir = os.path.join(log_dir, rank_str) diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py index 3cdfe0f9..887a96cd 100644 --- a/torchtitan/models/llama/__init__.py +++ b/torchtitan/models/llama/__init__.py @@ -48,4 +48,13 @@ multiple_of=4096, rope_theta=500000, ), + "405B": ModelArgs( + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), } diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 49cda624..e357f432 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -14,7 +14,7 @@ import torch import torch.nn.functional as F from torch import nn -from torchtitan.models.norms import create_norm +from torchtitan.models.norms import build_norm @dataclass @@ -190,9 +190,12 @@ def forward( bs, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) - xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) - xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) @@ -291,10 +294,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs): self.layer_id = layer_id self.num_layers = model_args.n_layers - self.attention_norm = create_norm( + self.attention_norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) - self.ffn_norm = create_norm( + self.ffn_norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) @@ -370,7 +373,7 @@ def __init__(self, model_args: ModelArgs): for layer_id in range(model_args.n_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) - self.norm = create_norm( + self.norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 10a6b853..c0ef6a80 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -18,18 +18,18 @@ from torch.distributed._tensor.experimental import local_map -def create_norm(norm_type: str, dim: int, eps: float = 1e-6): +def build_norm(norm_type: str, dim: int, eps: float = 1e-6): """ - Creates the specified normalization layer based on the norm_type. + Builds the specified normalization layer based on the norm_type. Args: - norm_type (str): The type of normalization layer to create. + norm_type (str): The type of normalization layer to build. 
Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm dim (int): The dimension of the normalization layer. eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. Returns: - The created normalization layer. + The built normalization layer. Raises: NotImplementedError: If an unknown norm_type is provided. diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py new file mode 100644 index 00000000..3f9eb3a8 --- /dev/null +++ b/torchtitan/optimizer.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools + +import torch +from torch.optim.lr_scheduler import LambdaLR +from torchtitan.config_manager import JobConfig + + +# consider split between PP and non-PP +def build_optimizers(model_parts, job_config: JobConfig): + """Wrap one optimizer per model part in an OptimizersContainer which provides a single + step() and zero_grad() method for all the child optimizers. + """ + + def _build_optimizer(model): + name = job_config.optimizer.name + lr = job_config.optimizer.lr + fused = job_config.optimizer.fused + + # Common parameters for both optimizers + optimizer_kwargs = { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1, + "fused": fused, + "foreach": not fused, + } + if name == "Adam": + # TODO: make the optimizer options configurable by toml/cmd args + optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) + elif name == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs) + else: + raise NotImplementedError(f"Optimizer {name} not added.") + + return optimizer + + class OptimizersContainer: + """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" + + def __init__(self, optimizers): + self.optimizers = optimizers + + def step(self): + for optimizer in self.optimizers: + optimizer.step() + + def zero_grad(self): + for optimizer in self.optimizers: + optimizer.zero_grad() + + return OptimizersContainer([_build_optimizer(model) for model in model_parts]) + + +def linear_warmup_linear_decay( + warmup_steps: int, decay_steps: int, current_step: int +) -> float: + """Computes linear warmup followed by linear decay. + Per LambdaLR requirement, this is accomplished by returning + a multiplicative factor to adjust the learning rate to + create the desired schedule. 
+ """ + if current_step < warmup_steps: + # linear warmup + # 0-indexed step, hence + 1 adjustments + current_step += 1 + curr_adjustment = float(current_step / (warmup_steps + 1)) + + else: + # linear decay + normalized_step = decay_steps - (current_step - warmup_steps) + curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps + + return curr_adjustment + + +def build_lr_schedulers(optimizers, job_config: JobConfig): + def _build_lr_scheduler(optimizer): + """Build a linear warmup and linear decay scheduler""" + warmup_steps = int(job_config.training.warmup_steps) + decay_steps = float(max(1, job_config.training.steps - warmup_steps)) + lr_lambda = functools.partial( + linear_warmup_linear_decay, warmup_steps, decay_steps + ) + warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return warmup_scheduler + + class SchedulersContainer: + """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages""" + + def __init__(self, schedulers): + self.schedulers = schedulers + + def step(self): + for schedulers in self.schedulers: + schedulers.step() + + return SchedulersContainer( + [_build_lr_scheduler(optimizer) for optimizer in optimizers] + ) diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index 7e1b21c7..7188474d 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -8,8 +8,17 @@ from functools import cached_property from torch.distributed.device_mesh import init_device_mesh -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama +from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule + + +__all__ = [ + "build_pipeline_schedule", + "models_parallelize_fns", + "models_pipelining_fns", + "ParallelDims", +] models_parallelize_fns = { "llama2": parallelize_llama, @@ -28,8 +37,10 @@ class ParallelDims: pp: int world_size: int enable_loss_parallel: bool + dp_type: str def __post_init__(self): + self.dp_type = self.dp_type.lower() self._validate() def _validate(self): @@ -42,6 +53,7 @@ def _validate(self): assert ( dp * tp * pp == self.world_size ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + assert self.dp_type in ("fsdp", "ddp") def build_mesh(self, device_type): dims = [] diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 90b7edba..a4b69344 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -9,11 +9,15 @@ import copy from collections import defaultdict -from typing import Dict, Tuple +from typing import Tuple, TYPE_CHECKING, Union import torch +import torch.nn as nn +from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy + +from torch.distributed._composable.replicate import replicate from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, @@ -28,9 +32,16 @@ ) from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP -from torchtitan.logging_utils import logger +from torchtitan.logging import logger +from torchtitan.models.llama.model import ModelArgs from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank +if TYPE_CHECKING: + from torchtitan.parallelisms 
import ParallelDims + + +DeviceType = Union[int, str, torch.device] + # for selective AC no_recompute_list = { torch.ops.aten.mm.default, @@ -106,15 +117,20 @@ def selective_checkpointing_context_fn(): return module -def get_tp_parallel_strategy( - job_config: JobConfig, +def get_tp_parallel_strategy_for_transformer_block( + enable_float8: bool, ) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: """Get the parallel strategy for the transformer model. This function handles the special case of using float8 with tensor parallelism. """ - if job_config.training.fp8_linear == "dynamic": - from float8_experimental.float8_tensor_parallel import ( + if enable_float8: + # TODO(vkuzo): once float8 configuration supports delayed + # scaling, add a check here to enforce supported float8 all-gather + # configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8, + # and import from there + from torchao.float8.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, PrepareFloat8ModuleInput, @@ -125,23 +141,30 @@ def get_tp_parallel_strategy( def pipeline_llama( - model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict + model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, + device: DeviceType, + model_config: ModelArgs, ): - if job_config.experimental.pipeline_parallel_split_mode == "manual": + split_mode = job_config.experimental.pipeline_parallel_split_mode + valid_split_modes = ("manual", "tracer") + if split_mode not in valid_split_modes: + raise ValueError( + f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}" + ) + if split_mode == "manual": return pipeline_llama_manual( - model, world_mesh, parallel_dims, job_config, device, model_config + model, pp_mesh, parallel_dims, job_config, device, model_config ) - elif job_config.experimental.pipeline_parallel_split_mode == "tracer": + elif split_mode == "tracer": return pipeline_llama_tracer( - model, world_mesh, parallel_dims, job_config, device, model_config - ) - else: - raise NotImplementedError( - f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode" + model, pp_mesh, parallel_dims, job_config, device, model_config ) -def _llama_trace_input(job_config, model_config, device="meta"): +def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"): """Get meta tensors with the right input shapes used for tracing""" tokens_shape = (job_config.training.batch_size, job_config.training.seq_len) tokens = torch.randint( @@ -153,18 +176,18 @@ def _llama_trace_input(job_config, model_config, device="meta"): def _mixed_precision_dtype( job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32 ) -> torch.dtype: - """Get the mixed precision dtype if fsdp is enabled, otherwise return the default""" + """Get the mixed precision dtype if FSDP is enabled, otherwise return the default""" mp_arg = job_config.training.mixed_precision_param return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default def pipeline_llama_manual( - whole_model, - world_mesh, - parallel_dims, + whole_model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: "ParallelDims", job_config: JobConfig, - device, - model_config: Dict, + device: DeviceType, + model_config: ModelArgs, ): """ This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. 
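For reference, the rank arithmetic used by the new `_get_metrics_rank` helper in `torchtitan/metrics.py` (earlier in this diff) can be exercised standalone. The sketch below is an illustrative re-statement, not torchtitan code; it assumes `pp` is the outermost mesh dimension, so the 0th rank of the last pipeline stage — the ranks that materialize the loss — sits at `(world_size // pp_size) * (pp_size - 1)`:

```python
# Runnable sketch of the rank selection behind _get_metrics_rank: with pipeline
# parallelism enabled, metrics are logged from the 0th rank of the last pipeline
# stage; otherwise from global rank 0.
def metrics_rank(world_size: int, pp_size: int, pp_enabled: bool) -> int:
    if pp_enabled:
        return (world_size // pp_size) * (pp_size - 1)
    return 0

assert metrics_rank(world_size=8, pp_size=1, pp_enabled=False) == 0
assert metrics_rank(world_size=8, pp_size=2, pp_enabled=True) == 4    # ranks 4-7 form the last stage
assert metrics_rank(world_size=16, pp_size=4, pp_enabled=True) == 12  # ranks 12-15 form the last stage
```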
@@ -174,7 +197,6 @@ def pipeline_llama_manual( The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD parallelism. """ - pp_mesh = world_mesh["pp"] pp_rank = pp_mesh.get_local_rank() pp_size = pp_mesh.size() microbatches = ( @@ -262,22 +284,26 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal def pipeline_llama_tracer( - model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict + model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, + device: DeviceType, + model_config: ModelArgs, ): if job_config.model.norm_type == "fused_rmsnorm": - # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode - # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr + # invocation stride in strict mode from `if dy.stride(-1) != 1:` in + # fused_rmsnorm raise NotImplementedError( - "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm." + "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm." ) - - if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16: + if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32: raise NotImplementedError( - "pipeline tracer doesn't work with fsdp mixed precision currently. " - "To work around, edit fsdp mixed precision config to use fp32." + "Pipeline tracer does not work with FSDP mixed precision yet. " + "To work around, set mixed_precision_param to float32." ) - pp_mesh = world_mesh["pp"] pp_rank = pp_mesh.get_local_rank() pp_size = pp_mesh.size() microbatches = ( @@ -310,19 +336,14 @@ def pipeline_llama_tracer( return (stages, models) -def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): - """ - Apply tensor parallelism. - """ - - tp_mesh = world_mesh["tp"] - ( - row_parallel_strategy, - col_parallel_strategy, - prepare_module_input, - ) = get_tp_parallel_strategy(job_config) - loss_parallel = parallel_dims.loss_parallel_enabled - +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first # transformer block's inputs) # 2. Parallelize the root norm layer over the sequence dim @@ -336,7 +357,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): output_layouts=Shard(1), ), "norm": SequenceParallel(), - "output": col_parallel_strategy( + "output": ColwiseParallel( input_layouts=Shard(1), output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, @@ -344,6 +365,14 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): }, ) + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears + ( + rowwise_parallel_weight, + colwise_parallel_weight, + prepare_module_input, + ) = get_tp_parallel_strategy_for_transformer_block(enable_float8) + # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. 
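A small self-contained illustration (tensor sizes are made up) of why `apply_tp` no longer needs to patch `n_heads`/`n_kv_heads` on each attention module: with `ColwiseParallel` applied to `wq`/`wk`/`wv`, each rank's projection output holds only its local heads, and the `view(bs, seqlen, -1, self.head_dim)` change in `torchtitan/models/llama/model.py` earlier in this diff infers that local head count from the tensor itself.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
bs, seqlen, n_heads, head_dim, tp_degree = 2, 16, 8, 64, 4

# With ColwiseParallel, wq's output features are sharded across tp_degree ranks,
# so each rank sees only (n_heads // tp_degree) heads' worth of columns.
local_q = torch.randn(bs, seqlen, (n_heads // tp_degree) * head_dim)

# Using -1 infers the local head count instead of hard-coding n_heads.
xq = local_q.view(bs, seqlen, -1, head_dim)
assert xq.shape == (bs, seqlen, n_heads // tp_degree, head_dim)
print(xq.shape)  # torch.Size([2, 16, 2, 64])
```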
@@ -355,48 +384,53 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): input_layouts=(Shard(1), None), desired_input_layouts=(Replicate(), None), ), - "attention.wq": col_parallel_strategy(), - "attention.wk": col_parallel_strategy(), - "attention.wv": col_parallel_strategy(), - "attention.wo": row_parallel_strategy(output_layouts=Shard(1)), + "attention.wq": colwise_parallel_weight(), + "attention.wk": colwise_parallel_weight(), + "attention.wv": colwise_parallel_weight(), + "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), "feed_forward": prepare_module_input( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), ), - "feed_forward.w1": col_parallel_strategy(), - "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)), - "feed_forward.w3": col_parallel_strategy(), + "feed_forward.w1": colwise_parallel_weight(), + "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel_weight(), } - # Adjust attention module to use the local number of heads - attn_layer = transformer_block.attention - attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() - attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() - parallelize_module( module=transformer_block, device_mesh=tp_mesh, parallelize_plan=layer_plan, ) - if job_config.experimental.enable_async_tensor_parallel: + # updates expressly for async tensor parallel + if enable_async_tp: from torch.distributed._symmetric_memory import enable_symm_mem_for_group + torch._dynamo.config.cache_size_limit = 10000 + logger.info( + "Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP" + ) + torch._inductor.config._micro_pipeline_tp = True enable_symm_mem_for_group(tp_mesh.get_group().group_name) - logger.info("Applied Tensor Parallelism to the model") - return model - + if not job_config.training.compile: + logger.warning( + "Async TP requires compilation...auto enabling compile = True for this job to resolve." + ) + job_config.training.compile = True -def apply_ac(model, job_config: JobConfig): - """ - Apply activation checkpointing to the model. - """ + logger.info( + f"Applied {'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + return model - ac_config = job_config.activation_checkpoint +def apply_ac(model: nn.Module, ac_config: JobConfig): + """Apply activation checkpointing to the model.""" for layer_id, transformer_block in model.layers.named_children(): transformer_block = checkpoint_wrapper(transformer_block, ac_config) model.layers.register_module(layer_id, transformer_block) @@ -405,19 +439,13 @@ def apply_ac(model, job_config: JobConfig): return model -def apply_compile(model, job_config: JobConfig): - """ - Apply torch.compile to the model. - """ - - if job_config.model.norm_type == "fused_rmsnorm": - raise NotImplementedError( - "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm." 
- ) +def apply_compile(model: nn.Module): + """Apply torch.compile to each transformer block.""" - # TODO(anijain): the following flag is on to accelarate compilation - # remove it after it's enabled in pytorch by default - torch._dynamo.config.inline_inbuilt_nn_modules = True + # the following flag can be used to to accelarate per-block compilation + # TODO(bdhirsh): turning it off because it's currently not working with 2D + # TODO(anijain): remove it after it's enabled in pytorch by default + # torch._dynamo.config.inline_inbuilt_nn_modules = True for layer_id, transformer_block in model.layers.named_children(): # turn on per-transformer block compile after AC wrapping and before FSDP @@ -428,22 +456,21 @@ def apply_compile(model, job_config: JobConfig): return model -def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig): +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, +): """ - Apply data parallelism (FSDP2) to the model. + Apply data parallelism to the model. FSDP2 is used here. """ - - dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh - assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - - mp_policy = MixedPrecisionPolicy( - param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], - reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], - ) + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} for layer_id, transformer_block in model.layers.items(): - if parallel_dims.pp_enabled: + if pp_enabled: # For PP, do not reshard after forward to avoid per-microbatch # all-gathers, which can be expensive and non-overlapped reshard_after_forward = False @@ -456,15 +483,49 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig): **fsdp_config, reshard_after_forward=reshard_after_forward, ) - fully_shard( - model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled - ) - + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + if pp_enabled: + # TODO + # This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since + # without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even + # without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be + # removed after strided sharding is landed in DCP. 
+ for module in model.modules(): + assert len(module._load_state_dict_pre_hooks) <= 1 + module._load_state_dict_pre_hooks.clear() + assert len(module._state_dict_pre_hooks) <= 1 + module._state_dict_pre_hooks.clear() logger.info("Applied FSDP to the model") return model -def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") + return model + + +def parallelize_llama( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: "ParallelDims", + job_config: JobConfig, +): """ Apply tensor parallelism, activation checkpointing, torch.compile, and data parallelism to the model. @@ -474,15 +535,46 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): """ if parallel_dims.tp_enabled: - model = apply_tp(model, world_mesh, parallel_dims, job_config) + model = apply_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=job_config.float8.enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) if job_config.activation_checkpoint.mode != "none": - model = apply_ac(model, job_config) + model = apply_ac(model, job_config.activation_checkpoint) if job_config.training.compile: - model = apply_compile(model, job_config) + if job_config.model.norm_type == "fused_rmsnorm": + raise NotImplementedError( + "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm." 
+ ) + model = apply_compile(model) if parallel_dims.dp_enabled: - model = apply_dp(model, world_mesh, parallel_dims, job_config) + if parallel_dims.dp_type == "fsdp": + dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh + assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names + + model = apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[ + job_config.training.mixed_precision_reduce + ], + pp_enabled=parallel_dims.pp_enabled, + ) + else: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism.") + model = apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) return model diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py index e60b7f51..aafe70fa 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -7,15 +7,16 @@ from torch.distributed.pipelining import ( Schedule1F1B, + ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ) -from torchtitan.logging_utils import logger +from torchtitan.logging import logger def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn): - looped_schedule = False + if job_config.experimental.pipeline_parallel_schedule == "1f1b": schedule_class = Schedule1F1B elif job_config.experimental.pipeline_parallel_schedule == "gpipe": @@ -23,6 +24,12 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn): elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b": schedule_class = ScheduleInterleaved1F1B looped_schedule = True + elif ( + job_config.experimental.pipeline_parallel_schedule + == "flexible_interleaved_1f1b" + ): + schedule_class = ScheduleFlexibleInterleaved1F1B + looped_schedule = True else: raise NotImplementedError( f"{job_config.experimental.pipeline_parallel_schedule} is not implemented" diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index 662b64f8..9da5c8fb 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -11,7 +11,7 @@ import torch from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger +from torchtitan.logging import logger # the number of warmup steps before the active step in each profiling cycle WARMUP = 3 diff --git a/torchtitan/utils.py b/torchtitan/utils.py index c2983660..3ed74d13 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import gc import os from dataclasses import dataclass from datetime import timedelta @@ -13,18 +14,17 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch.distributed.device_mesh import DeviceMesh -from torchtitan.logging_utils import logger -from torchtitan.parallelisms import ParallelDims +from torchtitan.logging import logger def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item() def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item() def _warn_overwrite_env(env, val): @@ -35,24 +35,6 @@ def _warn_overwrite_env(env, val): os.environ[env] = val -def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int: - """ - Returns global rank 0 in non-pipeline-parallel configs, and returns the global - rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. - """ - if parallel_dims.pp_enabled: - assert ( - world_mesh.mesh_dim_names[0] == "pp" - ), "get_metrics_rank assumes pp is the outer mesh dim" - pp_mesh = world_mesh["pp"] - pp_size = pp_mesh.size() - metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1)) - else: - metrics_log_rank = 0 - - return metrics_log_rank - - def set_pg_timeouts(timeout, world_mesh): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. @@ -80,6 +62,19 @@ def set_pg_timeouts(timeout, world_mesh): torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) +# used to avoid stragglers in garbage collection +class GarbageCollection: + def __init__(self, gc_freq=1000): + assert gc_freq > 0, "gc_freq must be a positive integer" + self.gc_freq = gc_freq + gc.disable() + gc.collect(1) + + def run(self, step_count): + if step_count > 1 and step_count % self.gc_freq == 0: + gc.collect(1) + + TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" diff --git a/train.py b/train.py index 8e55c210..615ed4e3 100644 --- a/train.py +++ b/train.py @@ -5,134 +5,43 @@ # LICENSE file in the root directory of this source tree. 
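# [Illustrative sketch, not part of the patch] The GarbageCollection helper added to
# torchtitan/utils.py above disables automatic gc and triggers a generation-1 collection
# every gc_freq steps, so all ranks collect at the same time instead of straggling.
# A minimal usage sketch (the loop below is a stand-in, not torchtitan's train loop):
import gc

class GarbageCollection:  # mirrors the helper added in the patch
    def __init__(self, gc_freq=1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()
        gc.collect(1)

    def run(self, step_count):
        if step_count > 1 and step_count % self.gc_freq == 0:
            gc.collect(1)

gc_handler = GarbageCollection(gc_freq=50)
for step in range(1, 201):
    gc_handler.run(step)  # collects at steps 50, 100, 150, 200
    # ... forward / backward / optimizer step would go here ...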
import contextlib -import gc import os import time - -from dataclasses import dataclass, field from datetime import timedelta -from io import BytesIO -from timeit import default_timer as timer -from typing import Any, Dict, List - -import numpy as np import torch -import torch.nn.functional as F -from torch.distributed import destroy_process_group -from torch.distributed.checkpoint.stateful import Stateful +import torchtitan.utils as utils from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.tensor.parallel import loss_parallel - -from torchtitan.checkpoint import CheckpointManager +from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig -from torchtitan.datasets import build_hf_data_loader, create_tokenizer -from torchtitan.float8_linear import build_fp8_linear -from torchtitan.logging_utils import init_logger, logger -from torchtitan.lr_scheduling import get_lr_schedulers +from torchtitan.datasets import build_hf_data_loader, build_tokenizer +from torchtitan.float8_linear import Float8Handler +from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import ( + build_pipeline_schedule, models_parallelize_fns, models_pipelining_fns, ParallelDims, ) -from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling -from torchtitan.utils import ( - Color, - dist_max, - dist_mean, - get_metrics_rank, - get_num_flop_per_token, - get_num_params, - get_peak_flops, - init_distributed, - NoColor, - set_pg_timeouts, -) - - -@dataclass -class TrainState(Stateful): - step: int = 0 - global_avg_losses: List[float] = field(default_factory=list) - global_max_losses: List[float] = field(default_factory=list) - log_steps: List[int] = field(default_factory=list) - - def state_dict(self) -> Dict[str, Any]: - # Only checkpoint global_avg_losses and global_max_losses per log frequency - # to avoid sync overhead in every iteration. - global_avg_losses_bytes = BytesIO() - torch.save(self.global_avg_losses, global_avg_losses_bytes) - global_max_losses_bytes = BytesIO() - torch.save(self.global_max_losses, global_max_losses_bytes) - log_steps_bytes = BytesIO() - torch.save(self.log_steps, log_steps_bytes) - return { - "step": torch.tensor(self.step, dtype=torch.int32), - "global_avg_losses": global_avg_losses_bytes, - "global_max_losses": global_max_losses_bytes, - "log_steps": log_steps_bytes, - } - - def load_state_dict(self, state_dict) -> None: - self.step = state_dict["step"].item() - state_dict["global_avg_losses"].seek(0) - self.global_avg_losses = torch.load( - state_dict["global_avg_losses"], weights_only=False - ) - state_dict["global_max_losses"].seek(0) - self.global_max_losses = torch.load( - state_dict["global_max_losses"], weights_only=False - ) - state_dict["log_steps"].seek(0) - self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) - - -def build_optimizers(model_parts, job_config: JobConfig): - """Wrap one optimizer per model part in an OptimizersContainer which provides a single - step() and zero_grad() method for all the child optimizers. 
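# [Illustrative sketch, not part of the patch] The TrainState dataclass removed from
# train.py here is now imported from torchtitan.checkpoint; its state_dict packs the
# loss history through BytesIO + torch.save so plain Python lists ride along in a
# distributed checkpoint. A minimal round trip of that pattern:
from io import BytesIO
import torch

losses = [2.31, 2.07, 1.98]
buffer = BytesIO()
torch.save(losses, buffer)              # pack the list into an in-memory buffer

buffer.seek(0)                          # rewind before reading, as load_state_dict does
restored = torch.load(buffer, weights_only=False)
assert restored == losses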
- """ - - def _build_optimizer(model): - name = job_config.optimizer.name - lr = job_config.optimizer.lr - fused = job_config.optimizer.fused - # Common parameters for both optimizers - optimizer_kwargs = { - "lr": lr, - "betas": (0.9, 0.95), - "weight_decay": 0.1, - "fused": fused, - "foreach": not fused, - } - if name == "Adam": - # TODO: make the optimizer options configurable by toml/cmd args - optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) - elif name == "AdamW": - optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs) - else: - raise NotImplementedError(f"Optimizer {name} not added.") - return optimizer - - class OptimizersContainer: - """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" - - def __init__(self, optimizers): - self.optimizers = optimizers - - def step(self): - for optimizer in self.optimizers: - optimizer.step() - - def zero_grad(self): - for optimizer in self.optimizers: - optimizer.zero_grad() +def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): + @contextlib.contextmanager + def context(): + with contextlib.ExitStack() as stack: + if enable_loss_parallel: + stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) + if enable_compiled_autograd: + stack.enter_context( + torch._dynamo.utils.maybe_enable_compiled_autograd(True) + ) + yield - return OptimizersContainer([_build_optimizer(model) for model in model_parts]) + return context # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @@ -142,12 +51,10 @@ def main(job_config: JobConfig): logger.info(f"Starting job: {job_config.job.description}") # used for colorful printing - color = Color if job_config.metrics.enable_color_printing else NoColor + color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor # take control of garbage collection to avoid stragglers - _gc_freq = job_config.training.gc_freq - gc.disable() - gc.collect(1) + gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) # init distributed world_size = int(os.environ["WORLD_SIZE"]) @@ -157,17 +64,20 @@ def main(job_config: JobConfig): pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, + dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) - init_distributed(job_config) + utils.init_distributed(job_config) + # initialize GPU memory monitor and get peak flops for MFU calculation + gpu_memory_monitor = build_gpu_memory_monitor() + gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name) # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] - dp_degree = dp_mesh.size() - dp_rank = dp_mesh.get_local_rank() + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 @@ -178,7 +88,7 @@ def main(job_config: JobConfig): # build tokenizer tokenizer_type = model_name_to_tokenizer[model_name] - tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path) + tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader data_loader = build_hf_data_loader( @@ -191,15 +101,6 @@ def main(job_config: JobConfig): dp_rank, ) - # loss_parallel enables dispatching to efficient loss operators - loss_parallel_ctx = ( - 
loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext - ) - - # loss fn can be shared by pipeline-parallel or non-pp execution - def loss_fn(pred, labels): - return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) - # build model (using meta init) model_cls = model_name_to_cls[model_name] model_config = models_config[model_name][job_config.model.flavor] @@ -215,14 +116,15 @@ def loss_fn(pred, labels): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) - # apply fp8 linear module swap - if job_config.training.fp8_linear: - build_fp8_linear(whole_model, job_config) + # a no-op hander if fp8 is not enabled + float8_handler = Float8Handler(job_config, parallel_dims) + # swap to Float8Linear base on fp8 config + float8_handler.convert_to_float8_training(whole_model) # log model size - model_param_count = get_num_params(whole_model) - num_flop_per_token = get_num_flop_per_token( - get_num_params(whole_model, exclude_embedding=True), + model_param_count = utils.get_num_params(whole_model) + num_flop_per_token = utils.get_num_flop_per_token( + utils.get_num_params(whole_model, exclude_embedding=True), model_config, job_config.training.seq_len, ) @@ -231,14 +133,9 @@ def loss_fn(pred, labels): f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) - # initialize GPU memory monitor before applying parallelisms to the model - gpu_memory_monitor = build_gpu_memory_monitor() - # obtain the peak flops of bf16 type for MFU calculation - gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - if parallel_dims.pp_enabled: stages, model_parts = models_pipelining_fns[model_name]( - whole_model, world_mesh, parallel_dims, job_config, device, model_config + whole_model, pp_mesh, parallel_dims, job_config, device, model_config ) else: # In 1D/2D cases or PP with simple schedules, model_parts is just one item @@ -256,6 +153,12 @@ def loss_fn(pred, labels): for model in model_parts: model.to_empty(device=init_device) + # loss fn can be shared by pipeline-parallel or non-pp execution + def loss_fn(pred, labels): + return torch.nn.functional.cross_entropy( + pred.flatten(0, 1), labels.flatten(0, 1) + ) + if parallel_dims.pp_enabled: pp_schedule = build_pipeline_schedule( job_config, parallel_dims, stages, loss_fn @@ -275,11 +178,7 @@ def loss_fn(pred, labels): # build optimizer after applying parallelisms to the model optimizers = build_optimizers(model_parts, job_config) - lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config) - - metric_logger = build_metric_logger( - job_config, metrics_log_rank=get_metrics_rank(world_mesh, parallel_dims) - ) + lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) train_state = TrainState() @@ -289,10 +188,10 @@ def loss_fn(pred, labels): # load initial checkpoint checkpoint = CheckpointManager( + dataloader=data_loader, model_parts=model_parts, optimizers=optimizers.optimizers, lr_schedulers=lr_schedulers.schedulers, - dataloader=data_loader, states={"train_state": train_state}, job_config=job_config, ) @@ -313,6 +212,8 @@ def loss_fn(pred, labels): "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" ) + metric_logger = build_metric_logger(job_config, parallel_dims) + # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. 
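# [Illustrative sketch, not part of the patch] The shared loss_fn defined above flattens
# the (bs, seq_len, vocab_size) logits and (bs, seq_len) labels into a single token
# dimension before cross entropy, so the same callable serves both pipeline-parallel and
# non-PP execution. A small shape check on random data:
import torch
import torch.nn.functional as F

bs, seq_len, vocab_size = 2, 8, 32
pred = torch.randn(bs, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bs, seq_len))

loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
print(loss.shape)  # torch.Size([]) - a scalar averaged over bs * seq_len tokens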
# This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq @@ -326,17 +227,29 @@ def loss_fn(pred, labels): data_iterator = iter(data_loader) - checkpoint.reset() + train_context = get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, + ) # variables used to keep info for metrics logging - losses_since_last_log: List[float] = [] + losses_since_last_log = [] ntokens_since_last_log = 0 - data_loading_times: List[float] = [] - time_last_log = timer() + data_loading_times = [] + time_last_log = time.perf_counter() gpu_memory_monitor.reset_peak_stats() + checkpoint.reset() + # train loop - logger.info(f"Training starts at step {train_state.step + 1}") + logger.info( + f"Training starts at step {train_state.step + 1}, " + f"with local batch size {job_config.training.batch_size}, " + f"global batch size {job_config.training.batch_size * dp_degree}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.training.warmup_steps})" + ) with maybe_enable_profiling( job_config, global_step=train_state.step ) as torch_profiler, maybe_enable_memory_snapshot( @@ -344,15 +257,14 @@ def loss_fn(pred, labels): ) as memory_profiler: while train_state.step < job_config.training.steps: train_state.step += 1 - if train_state.step > 1 and train_state.step % _gc_freq == 0: - gc.collect(1) + gc_handler.run(train_state.step) # get batch - data_load_start = timer() + data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch ntokens_since_last_log += labels.numel() - data_loading_times.append(timer() - data_load_start) + data_loading_times.append(time.perf_counter() - data_load_start) input_ids = input_ids.cuda() labels = labels.cuda() @@ -362,7 +274,7 @@ def loss_fn(pred, labels): # pipeline parallel forward / backward inside step() call is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 - with loss_parallel_ctx(): + with train_context(): if pp_mesh.get_local_rank() == 0: pp_schedule.step(input_ids) elif is_last_stage: @@ -379,7 +291,7 @@ def loss_fn(pred, labels): ) else: # Non-PP forward / backward - with loss_parallel_ctx(): + with train_context(): pred = model(input_ids) loss = loss_fn(pred, labels) # pred.shape=(bs, seq_len, vocab_size) @@ -393,11 +305,18 @@ def loss_fn(pred, labels): model.parameters(), job_config.training.max_norm, foreach=True ) + # sync float8 amaxes and scales + float8_handler.sync_float8_amax_and_scale_history(model) + # optimizer step - checkpoint.wait_for_staging() + checkpoint.maybe_wait_for_staging() optimizers.step() lr_schedulers.step() + # calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # it issues a single all-reduce for all parameters at once for better performance + float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) + losses_since_last_log.append(loss) # log metrics @@ -406,23 +325,21 @@ def loss_fn(pred, labels): or train_state.step % job_config.metrics.log_freq == 0 ): losses = [loss.item() for loss in losses_since_last_log] - avg_loss, max_loss = ( - np.mean(losses), - np.max(losses), - ) + avg_loss, max_loss = sum(losses) / len(losses), max(losses) if parallel_dims.dp_enabled: global_avg_loss, global_max_loss = ( - dist_mean(avg_loss, dp_mesh).item(), - dist_max(max_loss, dp_mesh).item(), + utils.dist_mean(avg_loss, dp_mesh), + utils.dist_max(max_loss, dp_mesh), ) else: global_avg_loss, global_max_loss = avg_loss, max_loss + # update train state 
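# [Illustrative sketch, not part of the patch] The loop above now times data loading
# with time.perf_counter() (replacing timeit.default_timer) and later reports both the
# mean per-step loading time and its share of the logging interval. A toy version of
# that bookkeeping over a stand-in iterator:
import time

data_iterator = iter(range(100))        # stand-in for the real dataloader
data_loading_times = []
time_last_log = time.perf_counter()

for _ in range(10):
    data_load_start = time.perf_counter()
    batch = next(data_iterator)
    data_loading_times.append(time.perf_counter() - data_load_start)

time_delta = time.perf_counter() - time_last_log
time_data_loading = sum(data_loading_times) / len(data_loading_times)
time_data_loading_pct = 100 * sum(data_loading_times) / time_delta
print(f"{time_data_loading:.6f}s avg, {time_data_loading_pct:.1f}% of the interval")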
train_state.log_steps.append(train_state.step) train_state.global_avg_losses.append(global_avg_loss) train_state.global_max_losses.append(global_max_loss) - time_delta = timer() - time_last_log + time_delta = time.perf_counter() - time_last_log # tokens per second, abbr. as wps by convention wps = ntokens_since_last_log / ( @@ -434,8 +351,8 @@ def loss_fn(pred, labels): mfu = 100 * num_flop_per_token * wps / gpu_peak_flops time_end_to_end = time_delta / job_config.metrics.log_freq - time_data_loading = np.mean(data_loading_times) - time_data_loading_pct = 100 * np.sum(data_loading_times) / time_delta + time_data_loading = sum(data_loading_times) / len(data_loading_times) + time_data_loading_pct = 100 * sum(data_loading_times) / time_delta gpu_mem_stats = gpu_memory_monitor.get_peak_stats() @@ -468,7 +385,7 @@ def loss_fn(pred, labels): losses_since_last_log.clear() ntokens_since_last_log = 0 data_loading_times.clear() - time_last_log = timer() + time_last_log = time.perf_counter() gpu_memory_monitor.reset_peak_stats() checkpoint.save( @@ -484,7 +401,7 @@ def loss_fn(pred, labels): # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished) if train_state.step == 1: - set_pg_timeouts( + utils.set_pg_timeouts( timeout=timedelta(seconds=job_config.comm.train_timeout_seconds), world_mesh=world_mesh, ) @@ -501,4 +418,4 @@ def loss_fn(pred, labels): config = JobConfig() config.parse_args() main(config) - destroy_process_group() + torch.distributed.destroy_process_group() diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index cb2fb215..7d4187dc 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -37,12 +37,12 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) [experimental] pipeline_parallel_degree = 1 +enable_async_tensor_parallel = false [checkpoint] enable_checkpoint = false @@ -56,3 +56,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index 05e3c27b..4727f965 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false compile = false dataset = "c4" @@ -52,3 +51,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index 5b2dd493..83114876 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -fp8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'full' # ['none', 'selective', 'full'] + +[float8] 
+enable_float8_linear = false diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index 9b72246a..22ab6c76 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -32,7 +32,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B -fp8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml new file mode 100644 index 00000000..fb250642 --- /dev/null +++ b/train_configs/llama3_405b.toml @@ -0,0 +1,53 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 128 H100 GPUs. + +[job] +dump_folder = "./outputs" +description = "Llama 3 405B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "405B" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm +tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" + +[optimizer] +name = "AdamW" +lr = 0.8e-4 + +[training] +batch_size = 2 +seq_len = 8192 +warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps +max_norm = 1.0 # grad norm clipping +steps = 3000 +data_parallel_degree = -1 +tensor_parallel_degree = 8 # 8-way TP +enable_float8_linear = false +compile = false +dataset = "c4" + +[experimental] +pipeline_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval_type = "steps" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 93b529f6..62d75dfb 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -fp8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'full' + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 95a53d56..517dd81e 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -fp8_linear = false compile = false dataset = "c4" @@ -52,3 +51,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false
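# [Illustrative sketch, not part of the patch] Each train_configs/*.toml above gains a
# [float8] table whose enable_float8_linear key replaces the old training.fp8_linear
# flag. A minimal check of how such a table parses (tomllib requires Python >= 3.11;
# the inline config string below is an assumption, not one of the real config files):
import tomllib

config_text = """
[training]
compile = false

[float8]
enable_float8_linear = false
"""

config = tomllib.loads(config_text)
assert config["float8"]["enable_float8_linear"] is False
print(config["float8"])  # {'enable_float8_linear': False}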