fix ptuning residuals bug #6866
Conversation
…T inference Signed-off-by: arendu <adithya.r@gmail.com>
for more information, see https://pre-commit.ci
…nto adithyare/lora_fix
Signed-off-by: arendu <adithya.r@gmail.com>
for more information, see https://pre-commit.ci
nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py
nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
-    def forward(self, batch_size):
+    def _forward(self,):
Better not to make this private for subclasses; rename it to forward_inner().
good point!
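A minimal sketch of the suggested rename, assuming a p-tuning-style adapter module (the class below is illustrative, not the actual NeMo PromptEncoderAdapter): forward() stays the public entry point and the shared embedding lookup lives in forward_inner(), which subclasses can override without reaching into a leading-underscore method.

```python
import torch
import torch.nn as nn


class PTuningAdapterSketch(nn.Module):
    """Illustrative p-tuning-style adapter; not the actual NeMo class."""

    def __init__(self, total_virtual_tokens: int, hidden_size: int):
        super().__init__()
        self.total_virtual_tokens = total_virtual_tokens
        self.embedding = nn.Embedding(total_virtual_tokens, hidden_size)

    def forward_inner(self) -> torch.Tensor:
        # Shared lookup of the virtual-token embeddings; subclasses can override
        # this without touching a "private" _forward() method.
        indices = torch.arange(self.total_virtual_tokens, device=self.embedding.weight.device)
        return self.embedding(indices)  # [T, H]

    def forward(self, batch_size: int) -> torch.Tensor:
        # Public entry point: expand the virtual tokens to the batch dimension.
        return self.forward_inner().unsqueeze(0).expand(batch_size, -1, -1)  # [B, T, H]


print(PTuningAdapterSketch(8, 16)(4).shape)  # torch.Size([4, 8, 16])
```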
-        virtual_embeddings = self.forward_single_enabled_adapter_(
-            _bs, ptuning_adapter, adapter_name=AdapterName.PTUNING_ADAPTER, adapter_strategy=strategy,
-        )
+        virtual_embeddings = ptuning_adapter(_bs)
Each adapter type has its own strategy; by removing this you're hard-coding the logic and sidestepping the strategy. It's not necessary to do that, is it?
Agreed, it's not necessary, but it makes readability and maintainability easier. It's clearer to read the code and see the residual connection directly than to go track down where a default strategy is coming from.
Oh, one other limitation, I think, is that I can't pass additional args? For example, the ptuning_adapter forward now accepts an additional arg like used_cached_reps.
After discussion, it seems this is never going to be needed by the NLP domain, so it's fine to directly call the adapter.
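To make the trade-off concrete, here is a small, self-contained sketch of the two call paths being weighed. forward_single_enabled_adapter_, AdapterName.PTUNING_ADAPTER, and the strategy object come from the diff above; the toy classes below are illustrative assumptions, not the actual NeMo implementation.

```python
import torch
import torch.nn as nn


class ToyPTuningAdapter(nn.Module):
    """Illustrative adapter: maps a batch size to a block of virtual-token embeddings."""

    def __init__(self, num_tokens: int = 8, hidden: int = 16):
        super().__init__()
        self.virtual = nn.Parameter(torch.randn(num_tokens, hidden))

    def forward(self, batch_size: int) -> torch.Tensor:
        return self.virtual.unsqueeze(0).expand(batch_size, -1, -1)  # [B, T, H]


adapter = ToyPTuningAdapter()

# Strategy-dispatched path (schematic): the caller routes the adapter through a
# registered strategy object, so the reader has to find out which strategy is in
# effect to know what comes back.
strategy = lambda bs, a: a(bs)  # stand-in for a registered adapter strategy
via_strategy = strategy(4, adapter)

# Direct path, as in this PR: the call (and any residual handling around it) is
# visible right where it happens.
virtual_embeddings = adapter(4)

assert torch.equal(via_strategy, virtual_embeddings)
```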
@@ -70,7 +70,7 @@ def __init__(
         self.prompt_embeddings.weight.requires_grad = False

         # Set fixed indicies for forward pass
-        self.register_buffer('indices', torch.LongTensor(list(range(self.total_virtual_tokens))))
+        self.register_buffer("indices", torch.LongTensor(list(range(self.total_virtual_tokens))), persistent=False)
Probably a breaking change, since older PEFT modules will have this buffer but newer ones won't. Need to check.
good catch! will check!
OK, this will indeed break older p-tuning checkpoints. However, it is a nice "feature" because older checkpoints will need to be converted anyway to a new param naming format. In that conversion step (which still needs to be written) I will remove the indices.
The persistent part will be the issue, but if that's ok then good enough
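For reference, a minimal plain-PyTorch sketch (no NeMo specifics; both modules are illustrative) of why persistent=False is the compatibility concern: the buffer disappears from state_dict(), so an older checkpoint that still carries 'indices' fails strict loading unless the key is dropped during the conversion step mentioned above.

```python
import torch
import torch.nn as nn


class OldModule(nn.Module):
    def __init__(self, total_virtual_tokens: int = 4):
        super().__init__()
        # Old behaviour: the buffer is persistent, so it is saved into the checkpoint.
        self.register_buffer("indices", torch.arange(total_virtual_tokens))


class NewModule(nn.Module):
    def __init__(self, total_virtual_tokens: int = 4):
        super().__init__()
        # New behaviour: persistent=False keeps the buffer out of state_dict().
        self.register_buffer("indices", torch.arange(total_virtual_tokens), persistent=False)


old_ckpt = OldModule().state_dict()
print(list(old_ckpt.keys()))                  # ['indices']
print(list(NewModule().state_dict().keys()))  # [] -- nothing saved any more

# Strict loading of the old checkpoint into the new module raises
# "Unexpected key(s) in state_dict: 'indices'" ...
try:
    NewModule().load_state_dict(old_ckpt)
except RuntimeError as err:
    print(err)

# ... so a conversion step can simply drop the stale key.
old_ckpt.pop("indices")
NewModule().load_state_dict(old_ckpt)
```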
…nto adithyare/lora_fix
What does this PR do?
Collection: [NLP]
Changelog
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.
Additional Information