Remove the "extra_state" for TE DPA module #14
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Remove the "extra_state" from the shared_state_dict of TEDotProductAttention module which was added from NVTE (ver >=1.9) , so that the dist checkpointing can ignore the weights for this extra_state for some earlier ckpts
Passed the manual pytest:
![image](https://private-user-images.githubusercontent.com/12014554/385923878-97cdf33b-05a4-4235-9188-b023330a3d65.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkzOTU1MTAsIm5iZiI6MTczOTM5NTIxMCwicGF0aCI6Ii8xMjAxNDU1NC8zODU5MjM4NzgtOTdjZGYzM2ItMDVhNC00MjM1LTkxODgtYjAyMzMzMGEzZDY1LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTIlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjEyVDIxMjAxMFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPThkYTljMTIzYTdkNDMwMTdiYWRmMWFjYjRkYTU3ODMxMWVkYjVkYTRiODkwZjA5Mzk2ODFhNDExZGI1NDNmZGImWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.F6FxFBuF9jsoS6Ooqr1e1qDQ5BP-XtBNseduBOAMGgA)
torchrun --nproc_per_node=8 -m pytest --color=yes -m "not flaky and not internal and not failing_on_rocm" --csv output/test_report.csv tests/unit_tests/