From cece95c73002bc13090fcfd434f7c15274d3475c Mon Sep 17 00:00:00 2001
From: ganler
Date: Fri, 19 Apr 2024 20:08:43 -0500
Subject: [PATCH] feat: configurable system message and disable sys msg for mixtral

---
 repoqa/provider/base.py             |  4 +---
 repoqa/provider/openai.py           |  6 +++---
 repoqa/provider/request/__init__.py |  7 +++++++
 repoqa/provider/request/openai.py   |  7 +++----
 repoqa/provider/vllm.py             | 11 ++++-------
 repoqa/search_needle_function.py    | 13 ++++++++++++-
 6 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/repoqa/provider/base.py b/repoqa/provider/base.py
index d41e351..2d373f0 100644
--- a/repoqa/provider/base.py
+++ b/repoqa/provider/base.py
@@ -5,12 +5,10 @@
 from abc import ABC, abstractmethod
 from typing import List
 
-SYSTEM_MSG = "You are a helpful assistant good at code understanding."
-
 
 class BaseProvider(ABC):
     @abstractmethod
     def generate_reply(
-        self, question, n=1, max_tokens=1024, temperature=0
+        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
     ) -> List[str]:
         ...
diff --git a/repoqa/provider/openai.py b/repoqa/provider/openai.py
index 9985b59..681fe18 100644
--- a/repoqa/provider/openai.py
+++ b/repoqa/provider/openai.py
@@ -7,7 +7,7 @@
 
 from openai import Client
 
-from repoqa.provider.base import SYSTEM_MSG, BaseProvider
+from repoqa.provider.base import BaseProvider
 from repoqa.provider.request.openai import make_auto_request
 
 
@@ -19,7 +19,7 @@ def __init__(self, model, base_url: str = None):
         )
 
     def generate_reply(
-        self, question, n=1, max_tokens=1024, temperature=0
+        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
     ) -> List[str]:
         assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
         replies = make_auto_request(
@@ -29,7 +29,7 @@
             temperature=temperature,
             n=n,
             max_tokens=max_tokens,
-            system_msg=SYSTEM_MSG,
+            system_msg=system_msg,
         )
 
         return [reply.message.content for reply in replies.choices]
diff --git a/repoqa/provider/request/__init__.py b/repoqa/provider/request/__init__.py
index 113d9a9..62276bc 100644
--- a/repoqa/provider/request/__init__.py
+++ b/repoqa/provider/request/__init__.py
@@ -1,3 +1,10 @@
 # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
 #
 # SPDX-License-Identifier: Apache-2.0
+
+
+def construct_message_list(message, system_message=None):
+    msglist = [{"role": "user", "content": message}]
+    if system_message:
+        msglist.insert(0, {"role": "system", "content": system_message})
+    return msglist
diff --git a/repoqa/provider/request/openai.py b/repoqa/provider/request/openai.py
index 1718706..4fb97da 100644
--- a/repoqa/provider/request/openai.py
+++ b/repoqa/provider/request/openai.py
@@ -8,6 +8,8 @@
 import openai
 from openai.types.chat import ChatCompletion
 
+from repoqa.provider.request import construct_message_list
+
 
 def make_request(
     client: openai.Client,
@@ -21,10 +23,7 @@
 ) -> ChatCompletion:
     return client.chat.completions.create(
         model=model,
-        messages=[
-            {"role": "system", "content": system_msg},
-            {"role": "user", "content": message},
-        ],
+        messages=construct_message_list(message, system_message=system_msg),
         max_tokens=max_tokens,
         temperature=temperature,
         n=n,
diff --git a/repoqa/provider/vllm.py b/repoqa/provider/vllm.py
index c27683c..5c5538a 100644
--- a/repoqa/provider/vllm.py
+++ b/repoqa/provider/vllm.py
@@ -7,7 +7,8 @@
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 
-from repoqa.provider.base import SYSTEM_MSG, BaseProvider
+from repoqa.provider.base import BaseProvider
+from repoqa.provider.request import construct_message_list
 
 
 class VllmProvider(BaseProvider):
@@ -20,15 +21,11 @@ def __init__(self, model, tensor_parallel_size, max_model_len):
         )
 
     def generate_reply(
-        self, question, n=1, max_tokens=1024, temperature=0
+        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
     ) -> List[str]:
         assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
         prompt = self.tokenizer.apply_chat_template(
-            [
-                {"role": "system", "content": SYSTEM_MSG},
-                {"role": "user", "content": question},
-            ],
-            tokenize=False,
+            construct_message_list(question, system_msg), tokenize=False
         )
         vllm_outputs = self.llm.generate(
             [prompt],
diff --git a/repoqa/search_needle_function.py b/repoqa/search_needle_function.py
index dac1b56..12d2f5b 100644
--- a/repoqa/search_needle_function.py
+++ b/repoqa/search_needle_function.py
@@ -163,6 +163,7 @@ def evaluate_model(
     result_dir: str = "results",
     languages: List[str] = None,
     caching: bool = False,  # if enabled, will cache the tasks which can be used to resume
+    system_message: str = "You are a helpful assistant good at code understanding.",
 ):
     if backend is None:
         if base_url is not None:
@@ -302,6 +303,14 @@
             max_model_len=int(code_context_size * 1.25),  # Magic number
         )
 
+    if model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+        system_message = None
+        print(f"Warning: {model} does not support system message")
+
+    if not system_message:
+        print("🔥 System message is disabled")
+        system_message = None
+
     with open(result_file, "a") as f_out:
         with progress(f"Running {model}") as pbar:
             for task in pbar.track(tasks):
@@ -316,7 +325,9 @@
                 for key in task["template"].split("\n"):
                     prompt += task[key]
 
-                replies = engine.generate_reply(prompt, n=1, max_tokens=max_new_tokens)
+                replies = engine.generate_reply(
+                    prompt, n=1, max_tokens=max_new_tokens, system_msg=system_message
+                )
                 result = {**task, "output": replies}
                 f_out.write(json.dumps(result) + "\n")
                 results.append(result)
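
For reference, a minimal usage sketch of the construct_message_list helper introduced by this patch (not part of the diff itself). The function body below is copied verbatim from repoqa/provider/request/__init__.py as changed above; the example question text is made up for illustration.

    # Illustrative sketch: behavior of the helper added in this patch.
    def construct_message_list(message, system_message=None):
        msglist = [{"role": "user", "content": message}]
        if system_message:
            msglist.insert(0, {"role": "system", "content": system_message})
        return msglist

    # Default path in search_needle_function.py: a system turn is prepended.
    print(construct_message_list(
        "Find the function described below.",  # hypothetical question text
        system_message="You are a helpful assistant good at code understanding.",
    ))
    # -> [{'role': 'system', 'content': '...'}, {'role': 'user', 'content': '...'}]

    # For mistralai/Mixtral-8x7B-Instruct-v0.1 the patch sets system_message = None,
    # so only the user turn is sent.
    print(construct_message_list("Find the function described below."))
    # -> [{'role': 'user', 'content': '...'}]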