Update base for Update on "Re-enable FSDP+TP w/ strided sharding"
**Summary**

1. Check whether users are running a nightly-build PyTorch that includes DTensor strided sharding when 2D/3D parallelism is used. Print a warning if not.
2. Remove the temporary re-enablement added in #460.

**Test**

Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8`

GPUs: A100

Output:

- without strided sharding:

```
[rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25%
[rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41%
[rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42%
[rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46%
[rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25%
[rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23%
[rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39%
[rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38%
[rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03%
[rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38%
[rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44%
[rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44%
[rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44%
[rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19%
[rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20%
[rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39%
[rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38%
[rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32%
```

- with strided sharding:

```
[rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03%
[rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16%
[rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44%
[rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37%
[rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40%
[rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26%
[rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20%
[rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36%
[rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33%
[rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33%
>>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<<
[rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03%
[rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38%
[rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44%
[rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42%
[rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42%
[rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18%
[rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15%
[rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39%
[rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39%
[rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33%
```

When to merge: when pytorch/pytorch#130760 is in the nightly build.

[ghstack-poisoned]
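For illustration only, the nightly-build check described in the Summary could look roughly like the sketch below. It is not the code from this commit: the function names, the `min_nightly_date` parameter, and the idea of comparing the `.devYYYYMMDD` stamp in the version string are all assumptions made for the example. Release builds without a `.dev` stamp are treated conservatively and also trigger the warning.

```python
import re
import warnings
from typing import Optional


def nightly_date(version: str) -> Optional[str]:
    """Return the YYYYMMDD stamp of a nightly version string, or None.

    Nightly wheels carry versions such as '2.5.0.dev20240806+cu121'.
    """
    match = re.search(r"\.dev(\d{8})", version)
    return match.group(1) if match else None


def warn_if_no_strided_sharding(
    version: str, parallel_dims: int, min_nightly_date: str
) -> bool:
    """Return True if the build is assumed to support strided sharding.

    version:          a version string, e.g. torch.__version__
    parallel_dims:    number of parallelism dimensions in use (1, 2, or 3)
    min_nightly_date: hypothetical YYYYMMDD cutoff supplied by the caller;
                      the real cutoff depends on when the upstream change
                      reaches the nightly build.
    """
    # 1D parallelism does not rely on strided sharding, so nothing to check.
    if parallel_dims < 2:
        return True

    date = nightly_date(version)
    # YYYYMMDD strings of equal length compare correctly as plain strings.
    if date is not None and date >= min_nightly_date:
        return True

    warnings.warn(
        f"PyTorch build {version!r} may not include DTensor strided sharding; "
        "2D/3D parallelism may not behave as expected."
    )
    return False
```

A caller would pass `torch.__version__` together with whatever cutoff date applies once the upstream change has landed; the function only warns and never raises, matching the "print warning if not" behaviour described above.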