Add LoRA and Prefix-Tuning as Modeling Options for Improved Memory Efficiency + performance (potentially) #2840
Conversation
❌ pre-commit failed.
@jordiclive would this PR make it possible to load a PEFT model for inference in the chat?
@smeyerhot This code is currently just for model training and evaluation, but it should be trivial to load it for inference; it uses the same HF `generate` method.
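For reference, a minimal sketch of what loading the adapter for inference could look like with the PEFT library. The model name and adapter directory are placeholders, not paths from this PR:

```python
# Hypothetical inference sketch; "your-base-model" and "adapter" are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("your-base-model")
base = AutoModelForCausalLM.from_pretrained("your-base-model", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, "adapter")  # attach the trained adapter weights
model.eval()

inputs = tokenizer("Hello, how can I help you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)  # the same HF generate method
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```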
@andreaskoepf I am going to run a 30B LoRA model just on the sft datasets and will post the sampling report.
```python
freeze_layer: bool = False
residual_dropout: float = 0
use_flash_attention: bool = False
adapter_save_path: str = "adapter_13B_new"
```
Does this only work on the 13B model? Should it be hardcoded?
It's the output path for a directory with the adapter weights in it. I've changed it to just "adapter".
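For context, PEFT's `save_pretrained` writes only the small adapter files to that directory; a minimal sketch, assuming `model` is a wrapped `PeftModel`:

```python
# Assumed usage sketch: saving a PeftModel writes only the adapter files
# (adapter_config.json plus the adapter weights), not the frozen base model.
model.save_pretrained("adapter")  # "adapter" = the adapter_save_path discussed above
```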
@smeyerhot I've updated the configs.
change log dir names and model name for lora 13B in config
change output path for adapter
Super great to have this!
This PR adds LoRA and prefix-tuning as modelling options (training and sampling code).
Both have shown strong performance and can outperform full fine-tuning. They can also protect against catastrophic forgetting, which is important for chatbots. Because the whole base language model stays frozen, the learned adapter weights can be distributed freely, independent of the base model.
They also allow much more memory-efficient training, since no optimizer states need to be kept for the base model's parameters.
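To illustrate the frozen-base-model point, here is a hedged sketch (assumptions, not this PR's exact configuration) of wrapping a causal LM with LoRA via the PEFT library; the target module names depend on the architecture:

```python
# Illustrative sketch: only the adapter parameters get gradients and optimizer
# states; the base model stays frozen. Config values here are assumptions.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PrefixTuningConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("your-base-model")  # placeholder

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                                  # rank of the low-rank update matrices
    lora_alpha=32,                         # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],   # assumed attention projection names
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # typically a fraction of a percent is trainable

# Prefix-tuning is configured analogously:
prefix_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30)
```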
Benefits of LoRA:
- See Andrej Karpathy (OpenAI) comment
- See purported Google leak
Implementation Details:
- Sampling uses the standard HF `generate` method.
- Only the adapter weights are saved, not the full `pytorch_model.bin`.
- Embeddings for special tokens: although these tokens are randomly initialized, they must be stored and saved as an external module, since the PEFT parameters learn to utilize them. Making them trainable is an option, but it is unlikely to make a significant difference.
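A hedged sketch of the special-token handling described above; the token names and output path are illustrative, not taken from this PR:

```python
# Sketch of the idea above: resize embeddings for new special tokens and save
# them as an external module next to the adapter. Names and paths are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-base-model")  # placeholder
model = AutoModelForCausalLM.from_pretrained("your-base-model")

tokenizer.add_special_tokens({"additional_special_tokens": ["<|prompter|>", "<|assistant|>"]})
model.resize_token_embeddings(len(tokenizer))  # new rows are randomly initialized

# The PEFT parameters learn to rely on these rows, so the embedding matrix must
# ship alongside the adapter even though it is not trained:
torch.save(model.get_input_embeddings().state_dict(), "adapter/extra_embeddings.pt")
```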