Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed base model class name extraction from PeftModels #27162

Merged
merged 3 commits into from
Nov 2, 2023
Merged

Fixed base model class name extraction from PeftModels #27162

merged 3 commits into from
Nov 2, 2023

Conversation

kkteru
Copy link
Contributor

@kkteru kkteru commented Oct 30, 2023

What does this PR do?

Fixes #27161

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@pacman100, @muellerzr

@amyeroberts
Copy link
Collaborator

cc @younesbelkada

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot!
I tried to reproduce and it is indeed a bug - however your fix does not properly take care of the case where users do DDP + PEFT. I propose to create first a dummy variable unwrapped_model that simply unwraps the model in case it is DDP or FSDP then perform the checks you suggested. What do you think?

@@ -2687,7 +2687,7 @@ def compute_loss(self, model, inputs, return_outputs=False):

if labels is not None:
if is_peft_available() and isinstance(model, PeftModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if is_peft_available() and isinstance(model, PeftModel):
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
del unwrapped_model

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't accept this suggestion otherwise it will create a weird diff but I did this so that you can see what I meant

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your great work on this!

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

Just a small question - once clarified I think we're good to merge :)

model_name = unwrap_model(model.base_model)._get_name()
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
model_name = unwrapped_model.base_model.model._get_name()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own understanding, was unwrap_model(model.base_model)._get_name() ever working?

I'm asking to understand whether we would still need the case
model_name = unwrapped_model.base_model._get_name()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was unwrap_model(model.base_model)._get_name() ever working?

If model is a DistrbutedDataParallel or a FSDP wrapped module model.base_model would fail (you need to unwrap model first) + unwrap_model(model.base_model)._get_name() would also fail because the model is stored in model.base_model.model. Let me know if you want more clarifications

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but what if the model isn't distributed? Is it possible for those models to be hitting this logic branch?

Copy link
Contributor Author

@kkteru kkteru Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the model isn't distributed and it is a PeftModel and w/ non-zero label-smoothing, unwrap_model(model.base_model)._get_name() would not fail but return the wrong model_name due to the reason mentioned in #27161.

IMO, the change @younesbelkada suggested fixes a different bug that I did not explicitly mention in the original issue, i.e., PeftModel with DDP/FSDP + label smoothing would throw an error. In other words, without the DDP/FSDP the fix of just replacing model.base_model with model.base_model.model would have worked.

Copy link
Contributor Author

@kkteru kkteru Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm asking to understand whether we would still need the case model_name = unwrapped_model.base_model._get_name()

I think in peft, the abstraction around the true base model is always PeftModel.base_model.model in all configurations from what I can tell (refer here and here). So this case may not be needed anymore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like a good idea. Let me try to do that, if that is okay. I see two potential changes in the neft activate and deactivate functions. I will try to do a more thorough scan and push another commit for review.

@younesbelkada Great work on integrating NEFT so quick, thank you so much!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds great, thank you @kkteru !

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed the changes. I don't think there is any other place where this needs to be changed. One place that came close was this prefix definition when loading the adapter state_dict, but I think that was correctly declared.

Copy link
Contributor Author

@kkteru kkteru Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If model is a DistrbutedDataParallel or a FSDP wrapped module model.base_model would fail (you need to unwrap model first) + unwrap_model(model.base_model)._get_name() would also fail because the model is stored in model.base_model.model. Let me know if you want more clarifications

I actually noticed similar potential issue in the sft_trainer/neft support of trl package here and here. Happy to push a PR there or leave it to you to clean up after.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, if you could submit PRs on TRL that would be great as well ! Thanks @kkteru !

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for your great work!

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing and iterating!

@amyeroberts amyeroberts merged commit 552ff24 into huggingface:main Nov 2, 2023
3 checks passed
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
…7162)

* Fixed base model class name extraction from PeftModels

* Changes to first unwrap the model then extract the base model name

* Changed base_model to base_model.model to stay consistent with peft model abstractions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Trainer doesn't shift labels for CAUSAL_LM PEFT models with label smoothing enabled
4 participants