Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the "extra_state" for TE DPA module #14

Merged
merged 1 commit into from
Nov 22, 2024

Conversation

wangye805
Copy link

@wangye805 wangye805 commented Nov 6, 2024

Remove the "extra_state" from the shared_state_dict of TEDotProductAttention module which was added from NVTE (ver >=1.9) , so that the dist checkpointing can ignore the weights for this extra_state for some earlier ckpts

Passed the manual pytest:
torchrun --nproc_per_node=8 -m pytest --color=yes -m "not flaky and not internal and not failing_on_rocm" --csv output/test_report.csv tests/unit_tests/
image

Copy link
Collaborator

@wenchenvincent wenchenvincent left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@wangye805 wangye805 force-pushed the dist_ckpt_dpa_fix branch 2 times, most recently from d21a895 to 746e080 Compare November 13, 2024 19:50
@wenchenvincent wenchenvincent merged commit 190213a into rocm_dev Nov 22, 2024
3 checks passed
@gurpreet-dhami gurpreet-dhami deleted the dist_ckpt_dpa_fix branch November 22, 2024 18:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants