FSDP checkpoints don't load when run is restarted with greater world size #811

Closed
darkmirage opened this issue Jan 28, 2025 · 4 comments · Fixed by #819
Labels: bug, documentation, enhancement, module: fsdp

Comments

@darkmirage commented Jan 28, 2025

A checkpoint is saved from an 8-GPU run with dp_shard set to 8 and all other parallelisms set to 1. My understanding is that this is configured as an FSDP run.

The checkpoint is resumed from 16 GPUs with dp_shard now set to 16. When loading the checkpoint, we get this error:

```
[rank0]: Traceback (most recent call last): (RANK 15)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank0]:     local_data = map_fun()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 211, in local_step
[rank0]:     local_plan = planner.create_local_plan()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 233, in create_local_plan
[rank0]:     return create_default_local_load_plan(
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 354, in create_default_local_load
[rank0]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank0]: RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_15.
```

My understanding is that torch distributed checkpoints are supposed to support dynamic resharding at load time. Does this not work with torchtitan?

I was able to successfully resume a checkpoint going down from 32 GPUs to 16.
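For context, the DCP pattern that does support resharding at load time looks roughly like the sketch below, using the public `torch.distributed.checkpoint` APIs. This is an illustrative sketch, not torchtitan's actual checkpoint code, and the checkpoint path is made up. Model and optimizer states are stored as sharded tensors, which DCP can re-slice for a new world size; the failing `dataloader.dp_rank_*` entries are not.

```python
# Sketch only: resharding-friendly save/load of model and optimizer states
# with torch.distributed.checkpoint (DCP). Paths and function names outside
# the torch APIs are illustrative.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

CKPT_DIR = "./outputs/checkpoint/step-1000"  # illustrative path

def save_checkpoint(model, optimizer):
    # get_state_dict returns sharded (DTensor-backed) state dicts under FSDP.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optimizer": optim_sd}, checkpoint_id=CKPT_DIR)

def load_checkpoint(model, optimizer):
    # This works even if the current world size differs from save time,
    # because DCP re-slices the sharded tensors at load time.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.load({"model": model_sd, "optimizer": optim_sd}, checkpoint_id=CKPT_DIR)
    set_state_dict(model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd)
```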

@fegin (Contributor) commented Jan 30, 2025

Yes, the dataloader and learning-rate scheduler do not support resharding, but the model and optimizer states do. We may be able to add selective resharding support.
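A rough sketch of why dataloader state cannot be resharded, based on the `dataloader.dp_rank_N` key pattern in the traceback above; the function and parameter names below are illustrative, not torchtitan's actual code:

```python
# Each data-parallel rank saves its own iterator state under its own key, so a
# checkpoint written with dp_degree=8 only contains dp_rank_0 .. dp_rank_7.
from typing import Any, Dict

def dataloader_checkpoint_entries(dataloader_state: Dict[str, Any], dp_rank: int) -> Dict[str, Any]:
    # The value is an opaque per-rank blob, not a sharded tensor, so DCP has
    # nothing it can re-slice across a different number of ranks.
    return {f"dataloader.dp_rank_{dp_rank}": dataloader_state}

# On restart with dp_degree=16, ranks 8..15 look up keys such as
# "dataloader.dp_rank_15" that were never written, producing the
# "Missing key in checkpoint state_dict" error above.
```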

@tianyu-l added the documentation and bug labels Jan 30, 2025
@tianyu-l (Contributor) commented Jan 30, 2025

> we may be able to add selective resharding support.

Before supporting this, we should error out in the data loader if the world size becomes smaller; otherwise it's a silent error.

@mori360 (Contributor) commented Feb 4, 2025

We added an error message in #816 for the case of loading with a lower dp_degree than the checkpoint was saved with.

For the error in the issue summary, i.e.

> [rank0]: RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_15.

which comes from loading with a higher dp_degree than the checkpoint was saved with, we plan to support optional checkpoint loading as a next step.

mori360 added a commit that referenced this issue Feb 4, 2025
…m checkpoint (#816)

Solves issue #811 by preventing users from running with data loader resharding.

DataLoader resharding is not supported yet.
Checkpoint loading behavior before this PR:

Case 1: save (dp:4) -> load (dp:4)
Checkpointing works as expected.

Case 2: save (dp:4) -> load (dp:2)
Runs successfully, but `dataloader.dp_rank_2` and `dataloader.dp_rank_3` are silently dropped.

Case 3: save (dp:2) -> load (dp:4)
Raises an error that `dataloader.dp_rank_2` and `dataloader.dp_rank_3` are not found in the checkpoint state_dict.

This PR raises an error in Case 2, since dataloader state would otherwise go silently missing.


In this PR, we store `dp_degree` (i.e. `dp_world_size`) in the dataloader state_dict.
After loading from a checkpoint, we compare the stored `dp_degree` against the current one.
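A minimal sketch of the check this describes, assuming a simplified stateful dataloader wrapper; class and attribute names are illustrative rather than torchtitan's actual `hf_datasets.py` code, and only the assertion message mirrors the error in the test log below:

```python
# Illustrative sketch of storing dp_degree in the dataloader state_dict and
# validating it at load time.
class ReshardAwareDataLoaderState:
    def __init__(self, dp_rank: int, dp_world_size: int):
        self._dp_rank = dp_rank
        self._world_size = dp_world_size  # dp_degree at save time

    def state_dict(self) -> dict:
        # Persist dp_degree alongside the per-rank state so that a mismatch
        # can be detected when resuming.
        return {"world_size": self._world_size, "rank": self._dp_rank}

    def load_state_dict(self, state_dict: dict) -> None:
        # Fail loudly if the run was restarted with a different dp_degree,
        # instead of silently dropping or missing per-rank dataloader state.
        assert self._world_size == state_dict["world_size"], (
            "dp_degree is inconsistent before and after checkpoint, "
            "DataLoader resharding is not supported yet."
        )
        self._dp_rank = state_dict["rank"]
```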


Tested with Case 2, loading from the checkpoint at step 3:
```
[rank0]:2025-02-03 13:39:06,055 - root - INFO - Starting job: Llama 3 8B training
[rank0]:2025-02-03 13:39:06,866 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:2025-02-03 13:39:06,868 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank0]:2025-02-03 13:39:06,920 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:2025-02-03 13:39:06,920 - root - INFO - Building 2-D device mesh with ['dp_shard', 'tp'], [2, 4]
[rank0]:2025-02-03 13:39:08,099 - root - INFO - Building tiktoken tokenizer locally from ./torchtitan/datasets/tokenizer/original/tokenizer.model
[rank0]:2025-02-03 13:39:08,283 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:2025-02-03 13:39:08,284 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:2025-02-03 13:39:13,047 - root - INFO - Building llama3 8B with ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, norm_type='rmsnorm')
[rank0]:2025-02-03 13:39:13,182 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-02-03 13:39:13,252 - root - INFO - Applied Tensor Parallelism to the model
[rank0]:2025-02-03 13:39:13,253 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-02-03 13:39:13,296 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2025-02-03 13:39:13,386 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2025-02-03 13:39:13,606 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-02-03 13:39:13,607 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:2025-02-03 13:39:13,607 - root - INFO - Loading the checkpoint at step 2.
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/data/users/.../torchtitan/train.py", line 433, in <module>
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/.../torchtitan/train.py", line 214, in main
[rank0]:[rank0]:     checkpoint.load(step=job_config.checkpoint.load_step)
[rank0]:[rank0]:   File "/data/users/.../torchtitan/torchtitan/checkpoint.py", line 441, in load
[rank0]:[rank0]:     dcp.load(
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:[rank0]:     result = func(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/checkpoint/utils.py", line 438, in inner_func
[rank0]:[rank0]:     return func(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/.../pytorch/torch/distributed/checkpoint/state_dict_loader.py", line 188, in load
[rank0]:[rank0]:     elem.load_state_dict(statetful_sd[key])
[rank0]:[rank0]:   File "/data/users/.../torchtitan/torchtitan/datasets/hf_datasets.py", line 178, in load_state_dict
[rank0]:[rank0]:     self._world_size == state_dict["world_size"]
[rank0]:[rank0]: AssertionError: dp_degree is inconsistent before and after checkpoint, DataLoader resharding is not supported yet.
```
@fegin (Contributor) commented Feb 5, 2025

We can close this issue after #819 is landed.

@tianyu-l linked a pull request Feb 5, 2025 that will close this issue