Commit 42589ae

tianyu-l, Andrew Gu, sanketpurandare, Yifu Wang, and vkuzo authored
merge upstream changes and add support for torchbench (#9)
* Set `record_shapes=True` for profiler
  ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
  Pull Request resolved: pytorch#419
* Improved `repeat_kv` eager perf
  ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
  Pull Request resolved: pytorch#418
* Adding FSDP Memory Tracking and Estimation
  ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
  Pull Request resolved: pytorch#425
* Adding integration test for FSDP Memory Tracking and Estimation
  ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
  Pull Request resolved: pytorch#426
* by default disable heavy memory profiling
  ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
  Pull Request resolved: pytorch#430
* Add the option to turn on async-TP
  ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
  Pull Request resolved: pytorch#429
* Modifying memory estimation options and minor changes
  ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
  Pull Request resolved: pytorch#435
* add comment pointing to Sequence Parallel optimization example
  ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
  Pull Request resolved: pytorch#438
* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)
  Summary: After pytorch-labs/float8_experimental#300, `Float8Linear` with default settings is equivalent to `Float8DynamicLinear`. This PR changes `torchtitan` to use `Float8Linear`. To support the new UX of `float8_experimental` better, I also switched the `fp8_linear` configuration to be a boolean on whether to swap the linears or not. In the future we can add new options on how to configure each linear (scaling type, scaling granularity, etc.) - saving that for a future PR.
  Test Plan:
  ```
  // run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
  // verify performance and loss values do not change meaningfully between
  // baseline and this PR
  // baseline (before this PR)
  // 1. compile, bf16
  // 2. compile, float8
  // 3. compile, float8, fdsp_fp8_allgather=True
  // 4. compile, float8, fdsp_fp8_allgather=True, tp=2
  // logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce
  // experiment (this PR): repeat all of the above, but with Float8Linear
  // logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
  ```
  Reviewers: Subscribers: Tasks: Tags:
* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`
  ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
  Pull Request resolved: pytorch#444
* Reordered TP parallel plan to follow execution order
  ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
  Pull Request resolved: pytorch#445
* Made some stylistic changes to `apply_dp`
  ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
  Pull Request resolved: pytorch#446
* Refactored activation checkpointing
  ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
  Pull Request resolved: pytorch#447
* compiled RMSNorm
  ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
  Pull Request resolved: pytorch#442
* Renamed parallel styles for transformer block weights
  ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
  Pull Request resolved: pytorch#448
* Added type annotations and more stylistic changes
  ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
  Pull Request resolved: pytorch#449
* [Cleanup] Remove libuv from run_llama_train.sh
  libuv is now enabled by default. We can probably do without the educational blurb there, and don't need the env either since the default has landed.
  ghstack-source-id: 68c8d2abe7eb0777e2add8df7634367c31b7ec06
  Pull Request resolved: pytorch#453
* [Cleanup] Organize run_llama_train.sh options
  Just a little code motion, but it looks cleaner to me this way.
  ghstack-source-id: 055fbd557cd9cf189e6b9bd6a7048f1204e1dc5c
  Pull Request resolved: pytorch#454
* [Cleanup] Split run_llama_train.sh and run_memory_estimation.sh
  Make each script simpler to read.
  ghstack-source-id: ba3aa65feb6e304736c73daf5bc8ab5fb254f196
  Pull Request resolved: pytorch#455
* [Cleanup] Remove unused TRAINER_DIR
  This argument seems to be left over from older times - it is not used anywhere in the codebase.
  ghstack-source-id: abbcf82ed4d1b8fbb71c6a6b48acbc1296dbec64
  Pull Request resolved: pytorch#456
* Add educational code pointers to top level README
  ghstack-source-id: 522aa2fa0bf1679f55d9f3a8a38fdcd319d5e3df
  Pull Request resolved: pytorch#457
* enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)
  We have landed fp8 all-gather optimizations in float8_experimental (pytorch-labs/float8_experimental#266). This PR proposes the torchtitan changes, and also includes fp8 in CI.
  ```
  from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
  # inside the training loop
  model(input).sum().backward()
  optim.step()
  precompute_float8_dynamic_scale_for_fsdp(model)
  ```
  FSDP2 fp8 all-gather is added to CI:
  ```
  CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
  CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
  CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
  ```
  TP fp8 all-gather is locally tested. Will add it to CI after uploading a new tokenizer with vocab size 2560 (divisible by 16):
  ```
  CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
  CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
  ```
  precompute scales after optimizer.step:
  <img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM" src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">
  FSDP2 pre-all-gather does not have any small all-reduces:
  <img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM" src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">
  TODO:
  * upload tokenizer with vocab size 2560 to enable CI on TP fp8 all-gather
  * torch.compile complains about fp8
  * add delayed scaling and brainstorm about best config option to express fp8
  * compare perf between delayed scaling and dynamic scaling https://github.com/pytorch-labs/float8_experimental/pull/312/files
* import float8_experimental only when fp8 is enabled and install it in CI (pytorch#464)
  Make sure to only import float8_experimental when fp8 is enabled. For 4 GPU CI, make sure we can import float8_experimental correctly in CI: `python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git`
* skip fp8 CI on non-H100 GPUs (pytorch#465)
  Skip fp8 tests on non-H100 GPUs by checking `torch.cuda.get_device_capability() >= (9, 0)`. This makes 4 GPU CI healthy again.
* clean up float8 configs in torchtitan (pytorch#466)
  Summary:
  1. standardizes on `float8` instead of `fp8` for config names
  2. removes usage of non-public objects such as `Float8Linear`
  Test Plan:
  ```
  with-proxy NGPU=1 CUDA_VISIBLE_DEVICES=7 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.compile --training.enable_float8_linear
  ```
  Reviewers: Subscribers: Tasks: Tags:
* Add support of DDP and experimental CompiledAutograd
  Summary: Address the comments in pytorch#319 and resubmit the PR to fit the current code base.
  Test Plan:
  ```
  CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600 --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
  ```
  ghstack-source-id: 81dc85d42df13df4ed727bebd825681879af936b
  Pull Request resolved: pytorch#432
* add torch.compile + FSDP2 float8 all-gather in CI (pytorch#468)
  Fixed my bug in float8_experimental. Now we can torch.compile transformer blocks with FSDP float8 all-gather (pytorch-labs/float8_experimental#321).
  Local test: `CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.compile`
  Profiler traces: I can see the compiled region in the CPU thread and the float8 matmul `sm90_xmma_gemm_e4m3bf16...` in the CUDA stream.
  <img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM" src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">
* [float8] keep model.output as `nn.Linear` (high precision, not fp8) (pytorch#469)
  **keep model.output as nn.Linear**: it's a common practice to NOT apply fp8 on the final output layer
  * specify `skip_fqn_list` in swapping
  * when applying TP to model.output, use plain `ColwiseParallel` instead of `Float8ColwiseParallel`
  Credit to @awgu, we do not need tokenizer vocab size to be divisible by 16 (pytorch#461).
  1D TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4`
  1D TP + float8 all-gather, compile mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4 --training.compile`
  2D FSDP2 + TP + float8 all-gather, eager mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2`
  2D FSDP2 + TP + float8 all-gather, compile mode: `CONFIG_FILE="./train_configs/debug_model.toml" NGPU=4 ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.tensor_parallel_degree 2 --training.compile`
  1D TP + float8 all-gather trace: see float8 and all-gather in the trace
  <img width="1611" alt="Screenshot 2024-07-19 at 1 16 59 PM" src="https://github.com/user-attachments/assets/9a95dfd9-40e0-4133-b2bb-e22ddf5b8472">
  2D + float8 all-gather trace: see float8 and FSDP collectives and TP collectives
  <img width="1038" alt="Screenshot 2024-07-19 at 1 29 59 PM" src="https://github.com/user-attachments/assets/6a34bcaa-bcae-402b-9994-cc892554fec7">
* remove CI for FSDP2 + fp8 all-gather (pytorch#470)
  Per discussion from pytorch#469 (comment), we are planning BC-breaking changes in float8_experimental. Remove CI for FSDP2 + fp8 all-gather for now. When public APIs are finalized, we can discuss bringing it back.
* dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (pytorch#471)
  This PR adds some enhancements for supporting async tp:
  1. If async tp is active, auto-updates the torch.dynamo cache limit to 10K. If this is not updated, async tp will not be activated on larger models as it will quietly stop compilation due to 'cache limit reached' with no info for the user. This config update is logged.
  2. If async tp is enabled, verifies that torch.compile is set to true for this job config. If not, it warns and then activates torch.compile to ensure the user gets working async tp. (see WARNING in the screenshot below)
  <img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM" src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">
  3. Updates the 'Applied Tensor Parallel' log message to 'Applied Async Tensor Parallel' when async tp is active, to make it clear in the logs which TP is active. (see above screenshot)
* Fix 8gpu PP failure due to 2D DCP disablement
  DCP recently added safeties to avoid using it for 2D/3D since strided sharding (a feature needed for safe 2D/3D resharding) is not ready yet. PP uses DCP to load a seed checkpoint. Disabling the safety mechanism is enough to make 3D/PP still work (for the case where we train from the beginning or do not re-shard). (Resharding refers to saving a checkpoint from one world size/parallelism config and loading/resuming under a different one.)
  ghstack-source-id: c069d2186c79517c72f5b3c99485cebdc15df08f
  Pull Request resolved: pytorch#460
* update float8 integration after UX changes (pytorch#484)
  Summary: float8_experimental landed various BC-breaking UX changes last week. This PR updates torchtitan to work with the version of float8_experimental after pytorch-labs/float8_experimental#332 and pytorch-labs/float8_experimental#337.
  Test Plan:
  ```
  with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
  ```
  Reviewers: Subscribers: Tasks: Tags:
* Re-enable FSDP2 Mem Tracker integration tests
  ghstack-source-id: 8344603f7a5596cb2909c9bf04dd1b9e4730c9b8
  Pull Request resolved: pytorch#485
* Used `partial` instead of global vars for LR scheduling
  ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40
  Pull Request resolved: pytorch#487
* [EZ] Add logs for some basic training params so that we can verify in… (pytorch#491)
  As title. While testing on the 405B model, I found that we need the logs for some basic training params, so added some here. Tested locally and the logging is shown as in the screenshot:
  <img width="900" alt="image" src="https://github.com/user-attachments/assets/b94e34f5-3e88-4c5f-94ed-75f50dde9786">
* make float8 scaling type configurable (pytorch#489)
  Summary: Adds config options to configure the float8 scaling type for input, weight, grad_output. Performance is not ideal yet, but that's because we have not optimized it.
  Test Plan:
  ```
  // repeat for input, weight, grad_out
  with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
  ```
  Reviewers: Subscribers: Tasks: Tags:
* [PP] add flexible interleaved 1f1b schedule pytorch#490 (pytorch#493)
  This was approved in pytorch#490, but merged into the wrong branch; merging this into main.
* move float8 callsites to torchao.float8 (pytorch#492)
  Summary: The `float8_experimental` repository moved to `torchao.float8` in pytorch/ao#551. This PR updates `torchtitan` to use float8 from the new location.
  Test Plan:
  ```
  with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
  ```
  Reviewers: Subscribers: Tasks: Tags:
* [BE][1/n] simplify train.py
  ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d
  Pull Request resolved: pytorch#494
* [BE][2/n] use proper method signatures in parallelize_llama
  ghstack-source-id: 17a1ee9f03f13423a30183c5c8d7ad30f8c8dbfc
  Pull Request resolved: pytorch#495
* [BE][3/n] wrap fp8 logic using Float8Handler
  ghstack-source-id: e94c7f6f4fad87c5432262c54beabd02de5541b8
  Pull Request resolved: pytorch#496
* Bring LLaMa 3.1 405B to TorchTitan family (pytorch#481)
  With the official launch of the LLaMa 3.1 model, we want to add the config to TorchTitan. Of course, there is more work to be done, but we want to go in an incremental way, so more PRs will be needed. For now, we tried on 128 GPUs with the current config (TP=8, FSDP=16). The perf number is wps: 109, mfu: 29%.
  Loss curve for 3000 steps with 600 warmup (lr = 0.8e-4).
  <img width="1037" alt="image" src="https://github.com/user-attachments/assets/f57dd3fa-07d8-4ef4-8f68-8f7a08e9652e">
  Loss curve for 3000 steps with 600 warmup (lr = 1.1e-4).
  ![image](https://github.com/user-attachments/assets/429b9738-94cb-4b37-90ef-049a5587ddd0)
* [TP] Infer local n_heads instead of ad-hoc model changes
  ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea
  Pull Request resolved: pytorch#498
* some compile-related updates
  ghstack-source-id: 63af8025c184fd5ad34f2f57bf78a37dda2cd33d
  Pull Request resolved: pytorch#443
* [EZ][405B] Use scientific notation for 405B model lr (pytorch#504)
  As title, use `8e-5` rather than `0.8e-4`.
* [BE][4/n] split pipeline_llama into a separate file
  ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
  Pull Request resolved: pytorch#499
* [fix] float8 should be applied on all model_parts
  ghstack-source-id: 52ed6836de39e82c4c5824a40ecfc1d9ec7ed2bd
  Pull Request resolved: pytorch#500
* Add warning to compile rmsnorm (pytorch#505)
  As titled, add warning to compile rmsnorm as it's not fully ready yet, i.e. this issue pytorch#497. We can remove this warning once we fix the issue.
* add float8 to README (pytorch#509)
  Add the float8 link in README so we can redirect people from the dev-discuss post to the torchtitan repo.
  README looks like this after rendering
  <img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM" src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">
  float8.md looks like this
  <img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM" src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">
  I tried the command locally and traces are looking good
  <img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM" src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">
* address TODOs as 2D recompiles is fixed
  ghstack-source-id: 2927f0a8082171da3e9f59a5d04f8325cbdf3653
  Pull Request resolved: pytorch#508
* [BE][5/n] simply pp vs. non-pp set up
  ghstack-source-id: 003bfbfbcf1511ddbd18e15d031b39f597d8e7db
  Pull Request resolved: pytorch#510
* [BE][6/n] replace large c4_mini datasets by c4_test with the first 2K entries
  ghstack-source-id: 319f4961b092778703101b98937803073132afa1
  Pull Request resolved: pytorch#512
* Create composability.md (pytorch#511)
  Explain the rationale and challenges behind certain changes we made to the llama model to support 3D parallelism.
  ---------
  Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
* depend on torchdata 0.8.0 instead of nightly
  ghstack-source-id: 1965d3122885fed3c28e2e058c55581187e7816c
  Pull Request resolved: pytorch#513
* add support for torchbench

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Co-authored-by: Will Constable <whc@meta.com>
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Less Wright <lessw@etrillium.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
Co-authored-by: Howard Huang <howardhuang96@gmail.com>
Co-authored-by: Ke Wen <kw2501@meta.com>
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Co-authored-by: Will Constable <willconstable@gmail.com>
1 parent d86885f commit 42589ae


50 files changed (+3672, -945 lines)

.ci/docker/requirements.txt

+1
@@ -1,4 +1,5 @@
 torch >= 2.3.0
+torchdata >= 0.8.0
 datasets >= 2.19.0
 tomli >= 1.1.0 ; python_version < "3.11"
 tensorboard

.github/workflows/integration_test_4gpu.yaml

+1-1
@@ -38,6 +38,6 @@ jobs:
 pip config --user set global.progress_bar off
 
 python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
-python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
+USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
 mkdir artifacts-to-be-uploaded
 python ./test_runner.py artifacts-to-be-uploaded --ngpu 4

.github/workflows/integration_test_8gpu.yaml

-1
@@ -37,6 +37,5 @@ jobs:
 pip config --user set global.progress_bar off
 
 python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
-python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
 mkdir artifacts-to-be-uploaded
 python ./test_runner.py artifacts-to-be-uploaded --ngpu 8

.github/workflows/unit_test_cpu.yaml

-1
@@ -25,5 +25,4 @@ jobs:
 pip config --user set global.progress_bar off
 
 pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
-pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
 pytest test --cov=. --cov-report=xml --durations=20 -vv

README.md

+16-7
@@ -18,6 +18,16 @@ Our guiding principles when building `torchtitan`:
 
 [![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!")
 
+### Dive into the code
+
+You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
+* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
+* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
+* [torchtitan/parallelisms/pipeline_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
+* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
+* [torchtitan/float8.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py) - utils for applying Float8 techniques
+* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)
+
 ## Pre-Release Updates:
 #### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
 Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly).
@@ -33,18 +43,18 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
 6. Learning rate scheduler, meta init, Optional Fused RMSNorm
 7. All options easily configured via [toml files](train_configs/)
 8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
+9. [Float8 support](docs/float8.md)
 
 We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
 
 
 ### Coming soon
 
 1. Async checkpointing
-2. FP8 support
-3. Context Parallel
-4. 3D Pipeline Parallel
-5. `torch.compile` support
-6. Scalable data loading solution
+2. Context Parallel
+3. 3D Pipeline Parallel
+4. `torch.compile` support
+5. Scalable data loading solution
 
 
 ## Installation
@@ -54,7 +64,6 @@ git clone https://github.com/pytorch/torchtitan
 cd torchtitan
 pip install -r requirements.txt
 pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
-pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
 ```
 
 ### Downloading a tokenizer
@@ -66,7 +75,7 @@ Once you have confirmed access, you can run the following command to download th
 ```bash
 # Get your HF token from https://huggingface.co/settings/tokens
 
-# llama3 tokenizer.model
+# llama3 or 3.1 tokenizer.model
 python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
 
 # llama2 tokenizer.model

benchmark.py

+232
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from datetime import timedelta

import torch
from torch.distributed.elastic.multiprocessing.errors import record

from torchbenchmark.util.experiment.instantiator import (
    load_model,
    TorchBenchModelConfig,
)
from torchbenchmark.util.experiment.metrics import get_model_flops
from torchbenchmark.util.input import input_cast

from torchtitan import utils
from torchtitan.checkpoint import TrainState
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_gpu_memory_monitor
from torchtitan.parallelisms import ParallelDims
from torchtitan.parallelisms.parallelize_llama import torch_spmd_parallelize
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
    init_logger()
    logger.info(f"Starting job: {job_config.job.description}")

    # used for colorful printing
    color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor

    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
        dp_type=job_config.training.data_parallel_type,
    )
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)
    utils.init_distributed(job_config)
    # initialize GPU memory monitor and get peak flops for MFU calculation
    gpu_memory_monitor = build_gpu_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        pp_mesh = world_mesh["pp"]

    model_name = job_config.model.name

    # initiate model from torchbench
    config = TorchBenchModelConfig(
        name=model_name,
        test="train",
        device="cuda",
        batch_size=job_config.training.batch_size,
        extra_args=[],
    )
    model_flops = get_model_flops(config)
    benchmark_model = load_model(config)
    model, _ = benchmark_model.get_module()

    # TODO: there seems to be a bug with dtype conversion (e.g. use resnet50)
    # cast input dtype if needed
    param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
    input_cond = lambda x: x.dtype == torch.float32
    input_action = lambda x: x.to(param_dtype)
    if hasattr(benchmark_model, "example_inputs"):
        benchmark_model.example_inputs = input_cast(
            input_cond, input_action, benchmark_model.example_inputs
        )
    else:
        logger.warning(
            f"{model_name} example inputs haven't been cast to {param_dtype} yet!"
        )

    # log model size
    model_param_count = utils.get_num_params(model)
    logger.info(
        f"{color.blue}Model {model_name} "
        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
    )

    # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
    model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)

    # update model and optimizer after applying parallelisms
    benchmark_model.set_module(model)
    optimizer = benchmark_model.get_optimizer()
    optimizer.add_param_group({"params": model.parameters()})

    model.train()

    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    logger.info(
        f"GPU memory usage for model: "
        f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
        f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
    )

    train_state = TrainState()

    # variables used to keep info for metrics logging
    losses_since_last_log = []
    gpu_memory_monitor.reset_peak_stats()

    # train loop
    logger.info(
        f"Training starts at step {train_state.step + 1}, "
        f"with local batch size {job_config.training.batch_size}, "
        f"global batch size {job_config.training.batch_size * dp_degree}, "
        f"total steps {job_config.training.steps}"
    )
    with maybe_enable_profiling(
        job_config, global_step=train_state.step
    ) as torch_profiler, maybe_enable_memory_snapshot(
        job_config, global_step=train_state.step
    ) as memory_profiler:
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            # Collect time_ns() instead of time(), which does not provide better precision than 1
            # second according to https://docs.python.org/3/library/time.html#time.time.
            t0 = time.time_ns()
            start_event.record()

            is_staged = (
                hasattr(benchmark_model, "forward")
                and hasattr(benchmark_model, "backward")
                and hasattr(benchmark_model, "optimizer_step")
            )
            if is_staged and (getattr(benchmark_model, "train", None) is None):
                if optimizer is not None:
                    optimizer.zero_grad()
                loss = benchmark_model.forward()
                benchmark_model.backward(loss)
                if optimizer is not None:
                    benchmark_model.optimizer_step()
            else:
                loss = benchmark_model.train()

            end_event.record()
            torch.cuda.synchronize()
            t1 = time.time_ns()
            time_delta = start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000

            # log metrics
            losses_since_last_log.append(loss)
            if (
                train_state.step == 1
                or train_state.step % job_config.metrics.log_freq == 0
            ):
                losses = [
                    loss.item() if isinstance(loss, torch.Tensor) else loss
                    for loss in losses_since_last_log
                ]
                avg_loss, max_loss = sum(losses) / len(losses), max(losses)
                if parallel_dims.dp_enabled:
                    global_avg_loss, global_max_loss = (
                        utils.dist_mean(avg_loss, dp_mesh),
                        utils.dist_max(max_loss, dp_mesh),
                    )
                else:
                    global_avg_loss, global_max_loss = avg_loss, max_loss

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                logger.info(
                    f"{color.cyan}step: {train_state.step:2} "
                    f"{color.green}loss: {global_avg_loss:7.4f} "
                    f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
                    f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
                    f"{color.blue}GPU time: {time_delta[0]:.3f}ms "
                    f"CPU wall time: {time_delta[1]:.3f}ms{color.reset}"
                )

                losses_since_last_log.clear()
                gpu_memory_monitor.reset_peak_stats()

            # signal the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # reduce timeout after first train step for faster signal
            # (assuming lazy init and compilation are finished)
            if train_state.step == 1:
                utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    logger.info("Training completed")


if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()
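
For reference, the step timing above pairs CUDA events (device time) with `time.time_ns()` (host wall time). A stripped-down sketch of just that measurement pattern, assuming a CUDA device and any callable `train_step`:

```python
import time
import torch

def time_one_step(train_step) -> tuple[float, float]:
    """Return (gpu_ms, cpu_ms) for a single call to train_step()."""
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()  # drain pending GPU work so timing starts clean
    t0 = time.time_ns()
    start_event.record()

    train_step()  # e.g. forward + backward + optimizer step

    end_event.record()
    torch.cuda.synchronize()  # wait for end_event before reading elapsed time
    gpu_ms = start_event.elapsed_time(end_event)
    cpu_ms = (time.time_ns() - t0) / 1_000_000
    return gpu_ms, cpu_ms
```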

create_seed_checkpoint.sh

-2
@@ -18,8 +18,6 @@
 
 set -ex
 
-export USE_LIBUV=1
-TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
 NGPU=1
 LOG_RANK=0
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

docs/composability.md

+23
@@ -0,0 +1,23 @@
# Building a Clean, Readable Distributed LLM
One of the main goals for TorchTitan was to provide a version of a distributed LLM that is not only high performance, but also uses native PyTorch techniques and readable code. The challenge is how to compose together so many individual library components (FSDP, TP, PP, FP8, Compile, DCP, ...), to name just a few, while avoiding too many changes to the model guts in the process. A lot of the work is behind the scenes, designing individual components to make fewer assumptions, use common abstractions (e.g. DTensor), and generally 'get along'. But we found a few tweaks to the model code invaluable as well, and wanted to share those changes and the rationale for them.


# Making the model "pipeline friendly"
When applying Pipeline Parallelism, you will have to construct nn.Module objects representing the portion of the model that runs on a given pipeline stage. Whether you plan to manually edit your model code, or use techniques like tracing to extract model chunks, a few changes to the original model code can go a long way to making this process easier.

### Simplifying the top-level model forward
Most likely, you can write your model in such a way that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child module forwards. If you can reduce your top level forward to mostly a for-loop over child module calls, then you'll simplify the pipeline-partitioning task to choosing the set of submodules to keep per stage. If you have non-trivial logic in the top-level forward, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.

Example ([PR #321](https://github.com/pytorch/torchtitan/pull/321)):
We used to slice the `freqs_cis` buffer by `seq_len` in the top-level forward, pass that into child modules, and expect that inside the child modules the `seq_len` would match up with the size of other local tensors. But we don't know whether TP was applied or not when we consider PP splitting, and could create a mismatch. It's just as easy to perform the `freqs_cis` slicing inside the child submodule, using the runtime-accurate local `seq_len`, and this sidesteps the issue at PP slicing time.

Example ([PR #322](https://github.com/pytorch/torchtitan/pull/322)): We decided to actually reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward would do the right thing. This means we don't have to make a separate runtime pp_forward that glues together child modules per stage. The first change was using a ModuleDict instead of a ModuleList to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers - e.g. layers.1 stays layers.1 even if you remove layers.0, which isn't true for a list - and this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP), since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional - if the layer exists, we run it, otherwise we feed the input through to bypass it. With these two changes, we can just (meta-)initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint.
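
To make the FQN point concrete, here is a tiny self-contained sketch (toy layers, not the actual torchtitan model code):

```python
import torch.nn as nn

# With nn.ModuleDict, deleting a layer keeps the surviving layers' FQNs stable;
# with nn.ModuleList, the survivors are renumbered and no longer match a checkpoint.
layers_dict = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(3)})
layers_list = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))

del layers_dict["0"]  # drop a stage-local subset: FQNs of the survivors are unchanged
del layers_list[0]    # survivors shift down: old layers.1 becomes layers.0

print([name for name, _ in layers_dict.named_parameters()])
# ['1.weight', '1.bias', '2.weight', '2.bias']  -> still matches the seed checkpoint
print([name for name, _ in layers_list.named_parameters()])
# ['0.weight', '0.bias', '1.weight', '1.bias']  -> no longer matches
```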

# Using a seed checkpoint for init
Initializing the pipeline-parallel model is challenging because we assume the model could be so large as to not fit on a local GPU (or possibly, even on CPU), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging or comparisons between runs. It's not that easy to rewrite the original model's `init_weights` function to be tolerant of initializing only some layers, while also serializing initialization operations globally for a consistent RNG order.

For now, we sidestep all these problems with a simple but brutal solution: initialize the whole model on some CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs that are present on a given PP stage after stage creation. For future work, we are considering adding a more elaborate initialization scheme to `torch.pipelining`.

One issue with seed checkpoints is that we rely on initializing _every_ model state from the checkpoint, which means the model can't have any non-persistent buffers; otherwise we have to specially initialize those in `train.py` after pipeline splitting. `freqs_cis` was originally a non-persistent buffer, and we changed this to persistent in order to load it from the seed checkpoint.
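
A condensed sketch of that flow, assuming the public `torch.distributed.checkpoint` save/load entry points with `checkpoint_id`; the toy model, checkpoint path, and per-stage split below are hypothetical placeholders, not torchtitan's actual checkpoint code:

```python
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn

class ToyModel(nn.Module):
    # Placeholder stand-in for the real model: layers live in a ModuleDict so
    # FQNs survive per-stage deletion (see the ModuleDict sketch above).
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): nn.Linear(16, 16) for i in range(4)})

# Step 1 (one-off, e.g. on a CPU node): fully initialize and save a seed checkpoint.
dcp.save({"model": ToyModel().state_dict()}, checkpoint_id="seed_ckpt")

# Step 2 (per PP rank): meta-init, delete the layers this stage doesn't own,
# materialize the rest, then load only the surviving FQNs from the seed checkpoint.
with torch.device("meta"):
    stage = ToyModel()
for name in ("0", "1"):       # placeholder split: this stage keeps layers 2 and 3
    del stage.layers[name]
stage.to_empty(device="cpu")  # or the local GPU once assigned
sd = {"model": stage.state_dict()}
dcp.load(sd, checkpoint_id="seed_ckpt")
stage.load_state_dict(sd["model"])
```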

docs/float8.md

+18
@@ -0,0 +1,18 @@
## Enable Float8 Training on H100s

Please install the latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/float8) to support the float8 dtype:
```
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
```

Launch the training job with the following command (or alternatively set the configs in the toml files):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
```
* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather, so we communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduces, one per parameter.

For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`).

For scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
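
As a rough illustration of what these flags do, here is a minimal sketch built on the `torchao.float8` APIs this commit imports (`convert_to_float8_training`, `precompute_float8_dynamic_scale_for_fsdp`); the toy model, sizes, and optimizer are placeholders rather than torchtitan's actual `Float8Handler` code, and an H100-class GPU is assumed:

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Placeholder model; torchtitan applies the swap to the Llama model parts instead.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()

# --float8.enable_float8_linear: swap eligible nn.Linear modules for Float8Linear
# (torchtitan keeps the final output projection in high precision via a filter).
convert_to_float8_training(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
x = torch.randn(8, 1024, device="cuda")

model(x).sum().backward()  # float8 matmuls run inside the swapped linears
optimizer.step()
# With --float8.enable_fsdp_float8_all_gather (plus FSDP2), torchtitan additionally calls
# torchao.float8.precompute_float8_dynamic_scale_for_fsdp(model) right after
# optimizer.step(), so per-parameter scales are produced with one fused all-reduce.
```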
