Add support to load QLoRA-tuned model
Signed-off-by: Angel Luu <angel.luu@us.ibm.com>
aluu317 committed Aug 26, 2024
1 parent daca551 commit af0b22e
Showing 3 changed files with 49 additions and 13 deletions.
29 changes: 23 additions & 6 deletions build/accelerate_launch.py
@@ -28,9 +28,9 @@

 # Third Party
 from accelerate.commands.launch import launch_command
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
 from peft import PeftModel
-from torch import bfloat16
+from torch import bfloat16, float16

 # Local
 from build.utils import (
@@ -142,11 +142,28 @@ def main():

         if os.path.exists(adapter_config_path):
             base_model_path = get_base_model_from_adapter_config(adapter_config_path)
-            base_model = AutoModelForCausalLM.from_pretrained(
-                base_model_path,
-                attn_implementation="flash_attention_2" if use_flash_attn else None,
-                torch_dtype=bfloat16 if use_flash_attn else None,
+
+            is_quantized = os.path.exists(
+                os.path.join(base_model_path, "quantize_config.json")
             )
+            if is_quantized:
+                print("Base model is GPTQ-quantized; loading with a GPTQ config")
+
+                gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
+                base_model = AutoModelForCausalLM.from_pretrained(
+                    base_model_path,
+                    attn_implementation="flash_attention_2" if use_flash_attn else None,
+                    device_map="auto",
+                    torch_dtype=float16 if use_flash_attn else None,
+                    quantization_config=gptq_config,
+                )
+            else:
+                print("Base model is not quantized; loading as usual")
+                base_model = AutoModelForCausalLM.from_pretrained(
+                    base_model_path,
+                    attn_implementation="flash_attention_2" if use_flash_attn else None,
+                    torch_dtype=bfloat16 if use_flash_attn else None,
+                )

             # since the peft library (PEFTModelForCausalLM) does not handle cases
             # where the model's layers are modified, in our case the embedding layer
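For context, the load path this change enables looks roughly like the sketch below: detect a GPTQ checkpoint by its quantize_config.json, load the base model through transformers' GPTQ integration, then attach the tuned LoRA adapter with peft. The paths, the use_flash_attn value, and the adapter-attachment step are illustrative assumptions, not lines from this diff.

# Hedged sketch of loading a QLoRA-tuned model (GPTQ base + LoRA adapter).
# base_model_path and lora_checkpoint_dir are placeholders; flash-attn and a
# CUDA device are assumed to be available when use_flash_attn is True.
import os

from peft import PeftModel
from torch import bfloat16, float16
from transformers import AutoModelForCausalLM, GPTQConfig

base_model_path = "/models/base"         # placeholder path
lora_checkpoint_dir = "/models/adapter"  # placeholder: output of the QLoRA run
use_flash_attn = True

if os.path.exists(os.path.join(base_model_path, "quantize_config.json")):
    # GPTQ checkpoint: load the 4-bit weights and use the exllama v2 kernels.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        attn_implementation="flash_attention_2" if use_flash_attn else None,
        device_map="auto",
        torch_dtype=float16 if use_flash_attn else None,
        quantization_config=GPTQConfig(bits=4, exllama_config={"version": 2}),
    )
else:
    # Unquantized checkpoint: keep the existing bfloat16 path.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        attn_implementation="flash_attention_2" if use_flash_attn else None,
        torch_dtype=bfloat16 if use_flash_attn else None,
    )

# Attach the tuned LoRA adapter on top of the (possibly quantized) base model.
model = PeftModel.from_pretrained(base_model, lora_checkpoint_dir)

The quantized branch switches to float16, presumably because the GPTQ/exllama kernels operate in fp16, while the unquantized branch keeps the existing bfloat16 default.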
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
     "tokenizers>=0.13.3,<1.0",
     "tqdm>=4.66.2,<5.0",
     "trl>=0.9.3,<1.0",
+    "optimum>=1.15.0",
     "peft>=0.8.0,<0.13",
     "datasets>=2.15.0,<3.0",
     "fire>=0.5.0,<1.0",
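The optimum pin covers transformers' GPTQ integration, but the 4-bit kernels usually come from auto-gptq, which this diff does not add; the check below is therefore an assumption about the runtime environment, not a documented requirement of this commit.

# Hedged sanity check that the GPTQ load path has its runtime pieces in place.
# "auto_gptq" is an assumption about the environment, not a dependency added here.
import importlib.util

for pkg in ("optimum", "auto_gptq"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'available' if found else 'missing'}")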
32 changes: 25 additions & 7 deletions scripts/run_inference.py
@@ -30,7 +30,7 @@
 # Third Party
 from peft import PeftModel
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
 import torch

 # Local
@@ -183,13 +183,31 @@ def load(
         try:
             if base_model_name_or_path is None:
                 raise ValueError("base_model_name_or_path has to be passed")
-            base_model = AutoModelForCausalLM.from_pretrained(
-                base_model_name_or_path,
-                attn_implementation="flash_attention_2"
-                if use_flash_attn
-                else None,
-                torch_dtype=torch.bfloat16 if use_flash_attn else None,
+
+            # TODO: what to do when base_model_name_or_path is a model name, and not a path
+            is_quantized = os.path.exists(
+                os.path.join(base_model_name_or_path, "quantize_config.json")
             )
+            if is_quantized:
+                gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
+                base_model = AutoModelForCausalLM.from_pretrained(
+                    base_model_name_or_path,
+                    attn_implementation="flash_attention_2"
+                    if use_flash_attn
+                    else None,
+                    device_map="auto",
+                    torch_dtype=torch.float16 if use_flash_attn else None,
+                    quantization_config=gptq_config,
+                )
+            else:
+                base_model = AutoModelForCausalLM.from_pretrained(
+                    base_model_name_or_path,
+                    attn_implementation="flash_attention_2"
+                    if use_flash_attn
+                    else None,
+                    torch_dtype=torch.bfloat16 if use_flash_attn else None,
+                )
+
             # since the peft library (PEFTModelForCausalLM) does not handle cases
             # where the model's layers are modified, in our case the embedding layer
             # is modified, so we resize the backbone model's embedding layer with our own
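One way the TODO above could be handled (an assumption, not part of this change): transformers records GPTQ settings in the checkpoint's config.json under quantization_config, and AutoConfig.from_pretrained resolves both local paths and Hub model IDs, so detection need not depend on a local quantize_config.json file.

# Hedged sketch for the TODO: detect GPTQ quantization from the model config so
# that Hub model IDs work as well as local paths. Not part of this commit.
from transformers import AutoConfig

def is_gptq_quantized(base_model_name_or_path: str) -> bool:
    config = AutoConfig.from_pretrained(base_model_name_or_path)
    quant_cfg = getattr(config, "quantization_config", None)
    if quant_cfg is None:
        return False
    if isinstance(quant_cfg, dict):
        return quant_cfg.get("quant_method") == "gptq"
    return getattr(quant_cfg, "quant_method", None) == "gptq"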
