FSDP checkpoints don't load when run is restarted with greater world size #811
Comments
Yes, the dataloader and learning rate scheduler do not support resharding, but model and optimizer states do, so we may be able to add selective resharding support.
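For reference, a minimal sketch of the DCP save/load pattern that makes model/optimizer resharding possible. This is an illustration using the public `torch.distributed.checkpoint` APIs, not torchtitan's actual checkpoint code; the tiny model, optimizer, and checkpoint path are made up:

```python
# Minimal sketch (not torchtitan's checkpoint code) of the DCP pattern that gives
# model and optimizer states their resharding behavior. In torchtitan these state
# dicts hold FSDP-sharded DTensors, and dcp.load() redistributes them onto the
# device mesh of the *current* run, which is why a different dp_shard degree is
# fine for model/optimizer states but not for the per-dp-rank dataloader state.
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

model = torch.nn.Linear(16, 16)                    # stands in for the FSDP-wrapped model
optimizer = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 16)).sum().backward()
optimizer.step()                                   # populate optimizer state before saving

# Save: entries are keyed by fully qualified tensor names, not by rank, so the
# on-disk layout does not bake in the dp degree used at save time.
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="outputs/sketch-step-10")

# Load: the same call works even if the current world size differs from the one
# used at save time, because DCP reshards tensors onto the current placements.
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id="outputs/sketch-step-10")
set_state_dict(model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd)
```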
Before supporting this, we should error out in the data loader if the world size becomes smaller; otherwise it is a silent error.
We added error messages in #816 for loading with a lower dp_degree from a checkpoint saved with a higher dp_degree.
For the error summarized here, which loads with a higher dp_degree from a checkpoint saved with a lower dp_degree, we plan to support optional checkpoint loading as the next step.
…m checkpoint (#816)

Solves the issue here (#811) by preventing users from running with data loader resharding, which is not supported yet. Checkpoint loading behavior before this PR:

- Case 1 (save dp:4 -> load dp:4): checkpoint loads successfully, as expected.
- Case 2 (save dp:4 -> load dp:2): runs successfully, but `dataloader.dp_rank_2` and `dataloader.dp_rank_3` are silently missing from the loaded state.
- Case 3 (save dp:2 -> load dp:4): raises an error that `dataloader.dp_rank_2` and `dataloader.dp_rank_3` are not found in the checkpoint state_dict.

This PR raises an error in Case 2, since dataloader info is missing. We store `dp_degree` (i.e. `dp_world_size`) in the dataloader state_dict; after loading from a checkpoint, we compare it with the current `dp_degree`. Tested with Case 2, loading from the checkpoint at step 3:

```
[rank0]:2025-02-03 13:39:06,055 - root - INFO - Starting job: Llama 3 8B training
[rank0]:2025-02-03 13:39:06,866 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:2025-02-03 13:39:06,868 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank0]:2025-02-03 13:39:06,920 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:2025-02-03 13:39:06,920 - root - INFO - Building 2-D device mesh with ['dp_shard', 'tp'], [2, 4]
[rank0]:2025-02-03 13:39:08,099 - root - INFO - Building tiktoken tokenizer locally from ./torchtitan/datasets/tokenizer/original/tokenizer.model
[rank0]:2025-02-03 13:39:08,283 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:2025-02-03 13:39:08,284 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:2025-02-03 13:39:13,047 - root - INFO - Building llama3 8B with ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, norm_type='rmsnorm')
[rank0]:2025-02-03 13:39:13,182 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-02-03 13:39:13,252 - root - INFO - Applied Tensor Parallelism to the model
[rank0]:2025-02-03 13:39:13,253 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-02-03 13:39:13,296 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2025-02-03 13:39:13,386 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2025-02-03 13:39:13,606 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-03 13:39:13,607 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:2025-02-03 13:39:13,607 - root - INFO - Loading the checkpoint at step 2.
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/data/users/.../torchtitan/train.py", line 433, in <module>
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/.../torchtitan/train.py", line 214, in main
[rank0]:[rank0]:     checkpoint.load(step=job_config.checkpoint.load_step)
[rank0]:[rank0]:   File "/data/users/.../torchtitan/torchtitan/checkpoint.py", line 441, in load
[rank0]:[rank0]:     dcp.load(
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:[rank0]:     result = func(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/checkpoint/utils.py", line 438, in inner_func
[rank0]:[rank0]:     return func(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/checkpoint/state_dict_loader.py", line 188, in load
[rank0]:[rank0]:     elem.load_state_dict(statetful_sd[key])
[rank0]:[rank0]:   File "/data/users/.../torchtitan/torchtitan/datasets/hf_datasets.py", line 178, in load_state_dict
[rank0]:[rank0]:     self._world_size == state_dict["world_size"]
[rank0]:[rank0]: AssertionError: dp_degree is inconsistent before and after checkpoint, DataLoader resharding is not supported yet.
```
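A rough sketch of the guard described above; the class and the `token_offset` field are illustrative stand-ins (the real change lives in `torchtitan/datasets/hf_datasets.py`, see #816), and only the `world_size` comparison mirrors the behavior shown in the traceback:

```python
# Illustrative sketch of the dp_degree consistency check described in the PR.
# The class and the "token_offset" field are hypothetical; only the world_size
# comparison mirrors the actual #816 behavior shown in the traceback above.
from typing import Any, Dict


class ShardedDataLoaderState:
    """Per-dp-rank dataloader state that records the dp degree it was sharded for."""

    def __init__(self, dp_rank: int, dp_world_size: int):
        self._dp_rank = dp_rank
        self._world_size = dp_world_size
        self._token_offset = 0  # stand-in for real per-rank dataset progress

    def state_dict(self) -> Dict[str, Any]:
        # Persist the dp degree alongside per-rank progress so a mismatch is
        # detected at load time instead of failing silently.
        return {"world_size": self._world_size, "token_offset": self._token_offset}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        assert self._world_size == state_dict["world_size"], (
            "dp_degree is inconsistent before and after checkpoint, "
            "DataLoader resharding is not supported yet."
        )
        self._token_offset = state_dict["token_offset"]
```

This catches Case 2 at load time; Case 3 already fails earlier because the `dataloader.dp_rank_N` keys for the extra ranks are not present in the checkpoint.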
We can close this issue after #819 is landed.
A checkpoint is saved from an 8-GPU run with `dp_shard` set to 8 and all other parallelisms set to 1. My understanding is that this is configured as an FSDP run. The checkpoint is then resumed on 16 GPUs with `dp_shard` now set to 16. When loading the checkpoint, we get this error:

My understanding is that torch distributed checkpoints are supposed to support dynamic resharding at load time. Does this not work with torchtitan?
I was able to successfully resume a checkpoint going down from 32 GPUs to 16.