Skip to content

Commit

Permalink
Update on "Re-enable FSDP+TP w/ strided sharding"
Browse files Browse the repository at this point in the history
**Summary**
1. re-enable FSDP+TP 2D in torchtitan.
2. remove temporary re-enablement added in #460 

**Test**
Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8`
GPUs: A100
Output:
- without strided sharding:
```
[rank7]:2024-08-06 03:21:26,706 - root - INFO - step:  2  loss:  8.1652  memory:  0.51GiB(0.64%)  wps: 8,250  mfu: 0.25%
[rank7]:2024-08-06 03:21:27,013 - root - INFO - step:  3  loss:  8.0951  memory:  0.51GiB(0.64%)  wps: 13,358  mfu: 0.41%
[rank7]:2024-08-06 03:21:27,309 - root - INFO - step:  4  loss:  7.9748  memory:  0.51GiB(0.64%)  wps: 13,865  mfu: 0.42%
[rank7]:2024-08-06 03:21:27,582 - root - INFO - step:  5  loss:  7.8025  memory:  0.51GiB(0.64%)  wps: 15,057  mfu: 0.46%
[rank7]:2024-08-06 03:21:28,076 - root - INFO - step:  6  loss:  7.5612  memory:  0.51GiB(0.64%)  wps: 8,300  mfu: 0.25%
[rank7]:2024-08-06 03:21:28,608 - root - INFO - step:  7  loss:  7.3649  memory:  0.51GiB(0.64%)  wps: 7,705  mfu: 0.23%
[rank7]:2024-08-06 03:21:28,927 - root - INFO - step:  8  loss:  7.2946  memory:  0.51GiB(0.64%)  wps: 12,832  mfu: 0.39%
[rank7]:2024-08-06 03:21:29,251 - root - INFO - step:  9  loss:  7.1311  memory:  0.51GiB(0.64%)  wps: 12,669  mfu: 0.38%
[rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10  loss:  7.0540  memory:  0.51GiB(0.64%)  wps: 10,918  mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11  loss:  7.0822  memory:  0.51GiB(0.64%)  wps: 1,139  mfu: 0.03%
[rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12  loss:  7.0508  memory:  0.51GiB(0.64%)  wps: 12,366  mfu: 0.38%
[rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13  loss:  6.9182  memory:  0.51GiB(0.64%)  wps: 14,370  mfu: 0.44%
[rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14  loss:  6.8948  memory:  0.51GiB(0.64%)  wps: 14,442  mfu: 0.44%
[rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15  loss:  6.8358  memory:  0.51GiB(0.64%)  wps: 14,514  mfu: 0.44%
[rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16  loss:  6.7653  memory:  0.51GiB(0.64%)  wps: 6,144  mfu: 0.19%
[rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17  loss:  6.7340  memory:  0.51GiB(0.64%)  wps: 6,453  mfu: 0.20%
[rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18  loss:  6.6874  memory:  0.51GiB(0.64%)  wps: 12,695  mfu: 0.39%
[rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19  loss:  6.6566  memory:  0.51GiB(0.64%)  wps: 12,406  mfu: 0.38%
[rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20  loss:  6.6629  memory:  0.51GiB(0.64%)  wps: 10,392  mfu: 0.32%
```
- with strided sharding
```
[rank7]:2024-08-06 03:26:18,288 - root - INFO - step:  1  loss:  8.2069  memory:  0.50GiB(0.63%)  wps: 915  mfu: 0.03%
[rank7]:2024-08-06 03:26:19,084 - root - INFO - step:  2  loss:  8.1913  memory:  0.51GiB(0.64%)  wps: 5,144  mfu: 0.16%
[rank7]:2024-08-06 03:26:19,365 - root - INFO - step:  3  loss:  8.1148  memory:  0.51GiB(0.64%)  wps: 14,593  mfu: 0.44%
[rank7]:2024-08-06 03:26:19,698 - root - INFO - step:  4  loss:  7.9982  memory:  0.51GiB(0.64%)  wps: 12,328  mfu: 0.37%
[rank7]:2024-08-06 03:26:20,011 - root - INFO - step:  5  loss:  7.8382  memory:  0.51GiB(0.64%)  wps: 13,100  mfu: 0.40%
[rank7]:2024-08-06 03:26:20,498 - root - INFO - step:  6  loss:  7.6293  memory:  0.51GiB(0.64%)  wps: 8,423  mfu: 0.26%
[rank7]:2024-08-06 03:26:21,126 - root - INFO - step:  7  loss:  7.4454  memory:  0.51GiB(0.64%)  wps: 6,530  mfu: 0.20%
[rank7]:2024-08-06 03:26:21,472 - root - INFO - step:  8  loss:  7.3337  memory:  0.51GiB(0.64%)  wps: 11,843  mfu: 0.36%
[rank7]:2024-08-06 03:26:21,849 - root - INFO - step:  9  loss:  7.1960  memory:  0.51GiB(0.64%)  wps: 10,892  mfu: 0.33%
[rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10  loss:  7.1208  memory:  0.51GiB(0.64%)  wps: 10,798  mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11  loss:  7.1222  memory:  0.51GiB(0.64%)  wps: 866  mfu: 0.03%
[rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12  loss:  7.1189  memory:  0.51GiB(0.64%)  wps: 12,589  mfu: 0.38%
[rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13  loss:  6.9646  memory:  0.51GiB(0.64%)  wps: 14,417  mfu: 0.44%
[rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14  loss:  6.9626  memory:  0.51GiB(0.64%)  wps: 13,680  mfu: 0.42%
[rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15  loss:  6.8694  memory:  0.51GiB(0.64%)  wps: 13,799  mfu: 0.42%
[rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16  loss:  6.7994  memory:  0.51GiB(0.64%)  wps: 5,910  mfu: 0.18%
[rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17  loss:  6.7634  memory:  0.51GiB(0.64%)  wps: 4,847  mfu: 0.15%
[rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18  loss:  6.7233  memory:  0.51GiB(0.64%)  wps: 12,915  mfu: 0.39%
[rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19  loss:  6.7054  memory:  0.51GiB(0.64%)  wps: 12,995  mfu: 0.39%
[rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20  loss:  6.7130  memory:  0.51GiB(0.64%)  wps: 10,991  mfu: 0.33%
```

When to merge:
when pytorch/pytorch#130760 is in nightly build.


[ghstack-poisoned]
  • Loading branch information
XilunWu committed Aug 9, 2024
1 parent b69c617 commit 695b207
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,23 @@ def apply_fsdp(
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

# TODO: This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check
# to avoid using 2D/3D DCP since without strided sharding, DCP can not safely support
# resharding for 2D/3D. We keep this safety check disablement here for now and will
# remove it when PyTorch gets the next minor release (PyTorch 2.5).
for module in model.modules():
assert len(module._load_state_dict_pre_hooks) <= 1
if len(module._load_state_dict_pre_hooks) == 1:
logger.warning(
"a safety check on 2D/3D DCP is detected. Please upgrade PyTorch to a "
"version newer than 2024-08-09 nightly release to include the change in "
"https://github.com/pytorch/pytorch/pull/130760"
)
module._load_state_dict_pre_hooks.clear()

assert len(module._state_dict_pre_hooks) <= 1
module._state_dict_pre_hooks.clear()

logger.info("Applied FSDP to the model")
return model

Expand Down

0 comments on commit 695b207

Please sign in to comment.