Fixed base model class name extraction from PeftModels #27162
Conversation
Thanks a lot!
I tried to reproduce, and it is indeed a bug. However, your fix does not properly take care of the case where users combine DDP + PEFT. I propose to first create a dummy variable unwrapped_model that simply unwraps the model in case it is DDP- or FSDP-wrapped, and then perform the checks you suggested. What do you think?
src/transformers/trainer.py
Outdated
@@ -2687,7 +2687,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
        if labels is not None:
            if is_peft_available() and isinstance(model, PeftModel):
Suggested change:
-            if is_peft_available() and isinstance(model, PeftModel):
+            unwrapped_model = unwrap_model(model)
+            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+                model_name = unwrapped_model.base_model.model._get_name()
+            else:
+                model_name = unwrapped_model._get_name()
+            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+                loss = self.label_smoother(outputs, labels, shift_labels=True)
+            else:
+                loss = self.label_smoother(outputs, labels)
+            del unwrapped_model
Don't accept this suggestion, otherwise it will create a weird diff, but I added it so that you can see what I meant.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks for your great work on this!
Thanks for fixing this!
Just a small question - once clarified I think we're good to merge :)
-            model_name = unwrap_model(model.base_model)._get_name()
+            unwrapped_model = unwrap_model(model)
+            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+                model_name = unwrapped_model.base_model.model._get_name()
For my own understanding, was unwrap_model(model.base_model)._get_name() ever working?
I'm asking to understand whether we would still need the case model_name = unwrapped_model.base_model._get_name()
> was unwrap_model(model.base_model)._get_name() ever working?

If model is a DistributedDataParallel or an FSDP-wrapped module, model.base_model would fail (you need to unwrap model first), and unwrap_model(model.base_model)._get_name() would also fail because the model is stored in model.base_model.model. Let me know if you want more clarifications.
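To make the nesting concrete, here is a minimal sketch of the attribute chain being discussed (assuming peft is installed; the gpt2 checkpoint and default LoRA config are purely illustrative, not what this PR tests):

```python
# Minimal sketch of the wrapping layers, assuming `peft` is installed; "gpt2" is only illustrative.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import unwrap_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

print(peft_model._get_name())                   # PeftModelForCausalLM
print(peft_model.base_model._get_name())        # LoraModel  (the PEFT wrapper, not the base model)
print(peft_model.base_model.model._get_name())  # GPT2LMHeadModel  (the true base model)

# Under DDP/FSDP the Trainer receives the wrapper, not the PeftModel, e.g.:
#   ddp_model = torch.nn.parallel.DistributedDataParallel(peft_model)
#   ddp_model.base_model                                   -> AttributeError (must unwrap first)
#   unwrap_model(ddp_model)                                -> peft_model
#   unwrap_model(ddp_model).base_model.model._get_name()   -> "GPT2LMHeadModel"
```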
OK, but what if the model isn't distributed? Is it possible for those models to hit this logic branch?
If the model isn't distributed, is a PeftModel, and label smoothing is non-zero, unwrap_model(model.base_model)._get_name() would not fail, but it would return the wrong model_name for the reason mentioned in #27161.

IMO, the change @younesbelkada suggested fixes a different bug that I did not explicitly mention in the original issue, i.e., a PeftModel with DDP/FSDP + label smoothing would throw an error. In other words, without DDP/FSDP, the fix of just replacing model.base_model with model.base_model.model would have worked.
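For illustration, a hypothetical sketch of what the old and new lookups resolve to in the non-distributed case (again assuming a gpt2 LoRA model as the example, and peft installed):

```python
# Hypothetical illustration of the non-distributed bug; names shown in comments
# are for a gpt2-based LoRA model, purely as an example.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("gpt2"), LoraConfig(task_type="CAUSAL_LM")
)

# Old code path: does not raise, but resolves to the PEFT wrapper class name.
old_name = unwrap_model(model.base_model)._get_name()          # "LoraModel"
print(old_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values())  # False -> shift_labels=False

# Fixed code path: resolves to the real causal-LM class name.
new_name = unwrap_model(model).base_model.model._get_name()    # "GPT2LMHeadModel"
print(new_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values())  # True -> shift_labels=True
```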
> I'm asking to understand whether we would still need the case model_name = unwrapped_model.base_model._get_name()

I think in peft, the abstraction around the true base model is always PeftModel.base_model.model in all configurations, from what I can tell (refer here and here). So this case may not be needed anymore.
That sounds like a good idea. Let me try to do that, if that is okay. I see two potential changes in the neft activate and deactivate functions. I will try to do a more thorough scan and push another commit for review.
@younesbelkada Great work on integrating NEFT so quickly, thank you so much!
Sounds great, thank you @kkteru!
Just pushed the changes. I don't think there is any other place where this needs to be changed. One place that came close was this prefix definition when loading the adapter state_dict, but I think that was correctly declared.
> If model is a DistributedDataParallel or an FSDP-wrapped module, model.base_model would fail (you need to unwrap model first), and unwrap_model(model.base_model)._get_name() would also fail because the model is stored in model.base_model.model. Let me know if you want more clarifications.

I actually noticed a similar potential issue in the sft_trainer/neft support of the trl package here and here. Happy to push a PR there, or leave it to you to clean up after.
Ah yes, if you could submit PRs on TRL that would be great as well! Thanks @kkteru!
Thanks again for your great work!
Thanks for fixing and iterating!
Fixed base model class name extraction from PeftModels (#27162)

* Fixed base model class name extraction from PeftModels
* Changes to first unwrap the model, then extract the base model name
* Changed base_model to base_model.model to stay consistent with peft model abstractions
What does this PR do?
Fixes #27161
Who can review?
@pacman100, @muellerzr