Put launch script code back
Signed-off-by: Angel Luu <angel.luu@us.ibm.com>
aluu317 committed Aug 30, 2024
1 parent 8254224 commit 23c71a2
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions build/accelerate_launch.py
@@ -28,9 +28,9 @@

 # Third Party
 from accelerate.commands.launch import launch_command
-from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
-from torch import bfloat16, float16
+from torch import bfloat16

 # Local
 from build.utils import (
@@ -142,28 +142,11 @@ def main():

     if os.path.exists(adapter_config_path):
         base_model_path = get_base_model_from_adapter_config(adapter_config_path)
-
-        is_quantized = os.path.exists(
-            os.path.join(base_model_path, "quantize_config.json")
-        )
-        if is_quantized:
-            print("ANGEL QLORA DEBUG: this model is quantized")
-
-            gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
-            base_model = AutoModelForCausalLM.from_pretrained(
-                base_model_path,
-                attn_implementation="flash_attention_2" if use_flash_attn else None,
-                device_map="auto",
-                torch_dtype=float16 if use_flash_attn else None,
-                quantization_config=gptq_config,
-            )
-        else:
-            print("ANGEL QLORA DEBUG: this model is NOT quantized")
-            base_model = AutoModelForCausalLM.from_pretrained(
-                base_model_path,
-                attn_implementation="flash_attention_2" if use_flash_attn else None,
-                torch_dtype=bfloat16 if use_flash_attn else None,
-            )
+        base_model = AutoModelForCausalLM.from_pretrained(
+            base_model_path,
+            attn_implementation="flash_attention_2" if use_flash_attn else None,
+            torch_dtype=bfloat16 if use_flash_attn else None,
+        )

         # since the peft library (PEFTModelForCausalLM) does not handle cases
         # where the model's layers are modified, in our case the embedding layer
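The in-file comment above motivates the step that follows in the script: the base model's embedding layer must be resized before the PEFT adapter is attached, because PeftModel.from_pretrained does not adjust modified layers itself. A minimal sketch of that pattern, using the transformers/peft APIs imported above; the paths are placeholders and the assumption that the tokenizer was saved alongside the adapter checkpoint is ours, not the commit's:

# Sketch only: resize embeddings before loading a PEFT adapter whose
# checkpoint was trained with added special tokens. Paths are hypothetical.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_path = "/path/to/base-model"       # hypothetical
adapter_path = "/path/to/adapter-checkpoint"  # hypothetical

tokenizer = AutoTokenizer.from_pretrained(adapter_path)
base_model = AutoModelForCausalLM.from_pretrained(base_model_path)

# Grow the embedding matrix to cover tokens added during tuning,
# since PeftModel.from_pretrained will not resize layers on its own.
base_model.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()  # optional: fold adapter weights into the base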
@@ -233,4 +216,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()
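For reference, the branch this commit deletes loaded GPTQ-quantized checkpoints through transformers' GPTQConfig. The sketch below is reconstructed from the deleted lines, with the flash-attention conditionals dropped for brevity; the model path is a placeholder, and running it requires a GPTQ backend (e.g. optimum/auto-gptq) to be installed:

# Sketch of the deleted GPTQ branch; not the project's current code path.
from transformers import AutoModelForCausalLM, GPTQConfig
from torch import float16

base_model_path = "/path/to/gptq-quantized-model"  # hypothetical

gptq_config = GPTQConfig(bits=4, exllama_config={"version": 2})
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    device_map="auto",  # place quantized weights on available GPUs
    torch_dtype=float16,
    quantization_config=gptq_config,
)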
