
Fix ChatGLM2 model loading issue (#510)
* Fix ChatGLM2 model loading issue

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel authored and VincyZhang committed Oct 23, 2023
1 parent e80fb9e commit 4f21697
Showing 1 changed file with 15 additions and 2 deletions.
intel_extension_for_transformers/neural_chat/models/model_utils.py: 15 additions, 2 deletions
@@ -28,6 +28,7 @@
 from typing import List
 from transformers import (
     GenerationConfig,
+    AutoModel,
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -331,9 +332,11 @@ def load_model(
         use_fast=False if (re.search("llama", model_name, re.IGNORECASE)
             or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)) else True,
         use_auth_token=hf_access_token,
-        trust_remote_code=True if (re.search("qwen", model_name, re.IGNORECASE)) else False,
+        trust_remote_code=True if (re.search("qwen", model_name, re.IGNORECASE) or \
+            re.search("chatglm", model_name, re.IGNORECASE)) else False,
     )
-    config = AutoConfig.from_pretrained(model_name, use_auth_token=hf_access_token)
+    config = AutoConfig.from_pretrained(model_name, use_auth_token=hf_access_token, trust_remote_code=True \
+        if re.search("chatglm", model_name, re.IGNORECASE) else False)
     load_to_meta = model_on_meta(config)
     if peft_path and device == "hpu" and use_deepspeed and load_to_meta:
         print("PEFT could not work in deepspeed sharded checkpt loading mode, set load_to_meta to False")
@@ -350,6 +353,14 @@ def load_model(
                 use_auth_token=hf_access_token,
                 quantization_config=bitsandbytes_quant_config,
             )
+    elif re.search("chatglm", model_name, re.IGNORECASE) and not ipex_int8:
+        with smart_context_manager(use_deepspeed=use_deepspeed):
+            model = AutoModel.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                use_auth_token=hf_access_token,
+                trust_remote_code=True)
     elif (
         re.search("gpt", model_name, re.IGNORECASE)
         or re.search("mpt", model_name, re.IGNORECASE)
@@ -394,11 +405,13 @@ def load_model(
     if (
         hasattr(model.generation_config, "pad_token_id")
         and model.generation_config.pad_token_id is not None
+        and not "chatglm" in model_name
     ):
         tokenizer.pad_token_id = model.generation_config.pad_token_id
     if (
         hasattr(model.generation_config, "eos_token_id")
         and model.generation_config.eos_token_id is not None
+        and not "chatglm" in model_name
     ):
         tokenizer.eos_token_id = model.generation_config.eos_token_id
     if (
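For readers skimming the diff, here is a minimal, self-contained sketch of the loading path this commit enables for ChatGLM-family models. It is only an illustration, not the repository's load_model function: the checkpoint name "THUDM/chatglm2-6b", the float32 dtype, and the absence of an auth token are assumptions made for the example.

```python
# Illustrative sketch only (assumed checkpoint name and dtype, no auth token).
# ChatGLM ships its modeling and tokenizer code on the Hugging Face Hub, so every
# from_pretrained call needs trust_remote_code=True, and the model is loaded via the
# generic AutoModel entry point rather than AutoModelForCausalLM.
import re

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_name = "THUDM/chatglm2-6b"  # assumed example checkpoint
is_chatglm = bool(re.search("chatglm", model_name, re.IGNORECASE))

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=is_chatglm)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=is_chatglm)  # mirrors the AutoConfig change

model = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.float32,   # assumed dtype for a CPU-friendly example
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()

# Note: the commit also skips copying pad_token_id/eos_token_id from
# model.generation_config into the tokenizer for "chatglm" model names,
# presumably because the ChatGLM tokenizer manages its own special token ids.
```

In the actual load_model change, the same trust_remote_code decision is driven by a re.search on the model name, and AutoModel is only used when the name matches "chatglm" and ipex_int8 is off.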
