-
Notifications
You must be signed in to change notification settings - Fork 259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Throw warning if users are using old pytorch version that not including DTensor strided sharding #507
Conversation
[ghstack-poisoned]
ghstack-source-id: c3e97f52a61e415fe216cc68460717e3eff04b58 Pull Request resolved: #507
**Summary** 1. re-enable FSDP+TP 2D in torchtitan. 2. remove temporary re-enablement added in #460 **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
ghstack-source-id: a746faaa43b35b59eaeef9dc4494a869459c511e Pull Request resolved: #507
@XilunWu FSDP2 registers load state dict post hooks because the loaded |
@awgu This detection is done before the first iteration. iirc, FSDP2 lazily initializes the load state dict hook? |
I think it is done at construction time :/ (see the code pointer in #507 (comment)). |
That is post hook? The hooks we remove here are pre hook for state_dict and load_state_dict. The hooks I mentioned are https://github.com/pytorch/pytorch/blob/e7b870c88bc3b854a95399a96a274d2f1f908172/torch/distributed/_composable/fsdp/_fsdp_param_group.py#L505 but I think they are installed during the first iteration? |
@fegin ah yes, those hooks are initialized on first iteration (lazy init). |
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
ghstack-source-id: 067f89a82b642003858afb576953447a2739fc51 Pull Request resolved: #507
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
ghstack-source-id: 13d0d4c3f11e5e130cb780557243da07e2564a3c Pull Request resolved: #507
…ch version that not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
ghstack-source-id: 259289fa3a81a781287832e953d4d3ca494a8075 Pull Request resolved: #507
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
…ch version that not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
ghstack-source-id: d5c663edc64ff764bf044c92e9a1d99ea640da74 Pull Request resolved: #507
ghstack-source-id: d5c663edc64ff764bf044c92e9a1d99ea640da74 Pull Request resolved: #507
…ng DTensor strided sharding (#507) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #507 **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ```
…g since 2.5 is official released" #507 added a PyTorch version check when users try to use FSDP+TP, to make sure the right PT version includes DTensor strided sharding which assures correct DTensor checkpoint. Since PyTorch 2.5 is official released and strided sharding is included in 2.5, we can safely remove this warning. [ghstack-poisoned]
… is official released (#665) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #665 #507 added a PyTorch version check when users try to use FSDP+TP, to make sure the right PT version includes DTensor strided sharding which assures correct DTensor checkpoint. Since PyTorch 2.5 is official released and strided sharding is included in 2.5, we can safely remove this warning.
… is official released (pytorch#665) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ pytorch#665 pytorch#507 added a PyTorch version check when users try to use FSDP+TP, to make sure the right PT version includes DTensor strided sharding which assures correct DTensor checkpoint. Since PyTorch 2.5 is official released and strided sharding is included in 2.5, we can safely remove this warning.
Stack from ghstack (oldest at bottom):
Summary
Test
Command:
python test_runner.py outputs --test pp_dp_tp --ngpu 8
GPUs: A100
Output:
When to merge:
when pytorch/pytorch#130760 is in nightly build.