From 4f216972535cac57dd7f06b8f2da1fb56615636c Mon Sep 17 00:00:00 2001
From: lvliang-intel
Date: Fri, 20 Oct 2023 07:17:42 +0800
Subject: [PATCH] Fix ChatGLM2 model loading issue (#510)

* Fix ChatGLM2 model loading issue

Signed-off-by: lvliang-intel
---
 .../neural_chat/models/model_utils.py        | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/intel_extension_for_transformers/neural_chat/models/model_utils.py b/intel_extension_for_transformers/neural_chat/models/model_utils.py
index c8a050ef114..d54c6a2c2e9 100644
--- a/intel_extension_for_transformers/neural_chat/models/model_utils.py
+++ b/intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -28,6 +28,7 @@
 from typing import List
 from transformers import (
     GenerationConfig,
+    AutoModel,
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -331,9 +332,11 @@ def load_model(
         use_fast=False if (re.search("llama", model_name, re.IGNORECASE)
             or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)) else True,
         use_auth_token=hf_access_token,
-        trust_remote_code=True if (re.search("qwen", model_name, re.IGNORECASE)) else False,
+        trust_remote_code=True if (re.search("qwen", model_name, re.IGNORECASE) or \
+            re.search("chatglm", model_name, re.IGNORECASE)) else False,
     )
-    config = AutoConfig.from_pretrained(model_name, use_auth_token=hf_access_token)
+    config = AutoConfig.from_pretrained(model_name, use_auth_token=hf_access_token, trust_remote_code=True \
+        if re.search("chatglm", model_name, re.IGNORECASE) else False)
     load_to_meta = model_on_meta(config)
     if peft_path and device == "hpu" and use_deepspeed and load_to_meta:
         print("PEFT could not work in deepspeed sharded checkpt loading mode, set load_to_meta to False")
@@ -350,6 +353,14 @@ def load_model(
                 use_auth_token=hf_access_token,
                 quantization_config=bitsandbytes_quant_config,
             )
+    elif re.search("chatglm", model_name, re.IGNORECASE) and not ipex_int8:
+        with smart_context_manager(use_deepspeed=use_deepspeed):
+            model = AutoModel.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                use_auth_token=hf_access_token,
+                trust_remote_code=True)
     elif (
         re.search("gpt", model_name, re.IGNORECASE)
         or re.search("mpt", model_name, re.IGNORECASE)
@@ -394,11 +405,13 @@ def load_model(
     if (
         hasattr(model.generation_config, "pad_token_id")
         and model.generation_config.pad_token_id is not None
+        and not "chatglm" in model_name
     ):
         tokenizer.pad_token_id = model.generation_config.pad_token_id
     if (
         hasattr(model.generation_config, "eos_token_id")
         and model.generation_config.eos_token_id is not None
+        and not "chatglm" in model_name
     ):
         tokenizer.eos_token_id = model.generation_config.eos_token_id
     if (
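
Reviewer note (not part of the patch): below is a minimal standalone sketch of the
loading path this change enables, assuming a ChatGLM2 checkpoint such as
THUDM/chatglm2-6b; the model ID and dtype are illustrative, not taken from the
repository's defaults.

    import re

    import torch
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    model_name = "THUDM/chatglm2-6b"  # illustrative checkpoint, adjust as needed

    # ChatGLM ships its modeling code inside the checkpoint repo, so the
    # tokenizer, config, and model all need trust_remote_code=True, and the
    # model is loaded through AutoModel rather than AutoModelForCausalLM
    # (ChatGLM2 is not registered under the causal-LM auto class in stock
    # transformers), mirroring the branch the patch adds.
    is_chatglm = bool(re.search("chatglm", model_name, re.IGNORECASE))

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=is_chatglm)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=is_chatglm)
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # illustrative; the patched code passes torch_dtype through
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    # ChatGLM's remote code manages its own pad/eos token ids, which is why the
    # patch also skips copying pad_token_id / eos_token_id from
    # model.generation_config onto the tokenizer for chatglm models.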