Skip to content

Commit

Permalink
[Fix] load_state_dict in nlp_model.py (NVIDIA#7086)
Browse files Browse the repository at this point in the history
* Fix load_state_dict in nlp_model.py

Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
stevehuang52 and pre-commit-ci[bot] authored Jul 27, 2023
1 parent ee40dce commit d8b535f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,4 +394,5 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
and "bert_model.embeddings.position_ids" in state_dict
):
del state_dict["bert_model.embeddings.position_ids"]
super(NLPModel, self).load_state_dict(state_dict, strict=strict)
results = super(NLPModel, self).load_state_dict(state_dict, strict=strict)
return results

0 comments on commit d8b535f

Please sign in to comment.