From 7baa96bc68f4802dcf3938ec33e00f0c9b6691c5 Mon Sep 17 00:00:00 2001
From: lvliang-intel
Date: Mon, 20 Nov 2023 10:08:06 +0800
Subject: [PATCH] Support CodeLlama model in NeuralChat (#711)

* Support neural-chat-7b-v3 and neural-chat-7b-v3-1

Signed-off-by: lvliang-intel
---
 .../neural_chat/chatbot.py                  |  2 +-
 .../neural_chat/models/base_model.py        |  4 +--
 .../neural_chat/models/model_utils.py       |  4 ++-
 .../tests/nightly/models/test_model.py      | 28 +++++++++++++++++++
 4 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/intel_extension_for_transformers/neural_chat/chatbot.py b/intel_extension_for_transformers/neural_chat/chatbot.py
index 5c2c52ebe6a..72e25003db2 100644
--- a/intel_extension_for_transformers/neural_chat/chatbot.py
+++ b/intel_extension_for_transformers/neural_chat/chatbot.py
@@ -72,7 +72,7 @@ def build_chatbot(config: PipelineConfig=None):
         adapter = BaseModel()
     else:
         raise ValueError("NeuralChat Error: Unsupported model name or path, \
-                         only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL now.")
+                         only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER now.")
 
     # register plugin instance in model adaptor
     if config.plugins:
diff --git a/intel_extension_for_transformers/neural_chat/models/base_model.py b/intel_extension_for_transformers/neural_chat/models/base_model.py
index 9f3ee37b920..1e8609821a8 100644
--- a/intel_extension_for_transformers/neural_chat/models/base_model.py
+++ b/intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -145,7 +145,7 @@ def predict_stream(self, query, config=None):
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
-            "starcoder" in self.model_name:
+            "starcoder" in self.model_name or "codellama" in self.model_name.lower():
             query_include_prompt = True
 
         # plugin pre actions
@@ -220,7 +220,7 @@ def predict(self, query, config=None):
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
         if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
-            "starcoder" in self.model_name:
+            "starcoder" in self.model_name or "codellama" in self.model_name.lower():
             query_include_prompt = True
 
         # plugin pre actions
diff --git a/intel_extension_for_transformers/neural_chat/models/model_utils.py b/intel_extension_for_transformers/neural_chat/models/model_utils.py
index cc5040fd942..ef4f405f883 100644
--- a/intel_extension_for_transformers/neural_chat/models/model_utils.py
+++ b/intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -365,6 +365,7 @@ def load_model(
         or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE)
         or re.search("qwen", model_name, re.IGNORECASE)
         or re.search("starcoder", model_name, re.IGNORECASE)
+        or re.search("codellama", model_name, re.IGNORECASE)
         or re.search("Mistral", model_name, re.IGNORECASE)
     ) and not ipex_int8) or re.search("opt", model_name, re.IGNORECASE):
         with smart_context_manager(use_deepspeed=use_deepspeed):
@@ -377,6 +378,7 @@
             )
     elif (
             (re.search("starcoder", model_name, re.IGNORECASE)
+            or re.search("codellama", model_name, re.IGNORECASE)
             ) and ipex_int8
     ):
         with smart_context_manager(use_deepspeed=use_deepspeed):
@@ -389,7 +391,7 @@
     else:
         raise ValueError(
             f"Unsupported model {model_name}, only supports "
-            "FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL now."
+            "FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER now."
         )
 
     if re.search("llama", model.config.architectures[0], re.IGNORECASE):
diff --git a/intel_extension_for_transformers/neural_chat/tests/nightly/models/test_model.py b/intel_extension_for_transformers/neural_chat/tests/nightly/models/test_model.py
index 7ac1f28da18..454eac0c1ea 100644
--- a/intel_extension_for_transformers/neural_chat/tests/nightly/models/test_model.py
+++ b/intel_extension_for_transformers/neural_chat/tests/nightly/models/test_model.py
@@ -144,5 +144,33 @@ def test_get_default_conv_template_v3_1(self):
         print(result)
         self.assertIn('The Intel Xeon Scalable Processors', str(result))
 
+class TestStarCoderModel(unittest.TestCase):
+    def setUp(self):
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def test_code_gen(self):
+        config = PipelineConfig(model_name_or_path="bigcode/starcoder")
+        chatbot = build_chatbot(config=config)
+        result = chatbot.predict("def print_hello_world():")
+        print(result)
+        self.assertIn("""print('Hello World')""", str(result))
+
+class TestCodeLlamaModel(unittest.TestCase):
+    def setUp(self):
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def test_code_gen(self):
+        config = PipelineConfig(model_name_or_path="codellama/CodeLlama-7b-hf")
+        chatbot = build_chatbot(config=config)
+        result = chatbot.predict("def print_hello_world():")
+        print(result)
+        self.assertIn("""print('Hello World')""", str(result))
+
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file
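
For reference, here is a minimal usage sketch of the CodeLlama support added above. It mirrors the new TestCodeLlamaModel nightly test: the model path and prompt come straight from test_model.py, while the import path is an assumption based on the package layout in the diff. Running it requires the CodeLlama weights to be downloadable from the Hugging Face Hub.

    # Sketch only: mirrors TestCodeLlamaModel from the nightly test added in this patch.
    # The import path is assumed from the package layout; adjust if the public API differs.
    from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

    # Build a chatbot backed by CodeLlama; "bigcode/starcoder" can be passed the same way.
    config = PipelineConfig(model_name_or_path="codellama/CodeLlama-7b-hf")
    chatbot = build_chatbot(config=config)

    # Code-completion style prompt. With this patch, queries for codellama/starcoder
    # model names are treated as already containing the full prompt (see base_model.py),
    # so no chat template is wrapped around the text.
    result = chatbot.predict("def print_hello_world():")
    print(result)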