Skip to content

Commit

Permalink
[ROCm] return an empty state dict for TE DPA module
Browse files Browse the repository at this point in the history
  • Loading branch information
wangye805 committed Nov 6, 2024
1 parent a2d0bdf commit 871c161
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,15 @@ def forward(
else:
return core_attn_out

def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
state_dict = self.state_dict(prefix='', keep_vars=True)
# TE with version>=1.9 introduces an extra state in DotProductAttention Module
if is_te_min_version("1.9.0.dev0") and ('_extra_state' in state_dict):
state_dict.pop('_extra_state')
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {}, sharded_offsets
)


if is_te_min_version("1.9.0.dev0"):

Expand Down

0 comments on commit 871c161

Please sign in to comment.