feat: configurable system message and disable sys msg for mixtral
ganler committed Apr 20, 2024
1 parent 4e5a109 commit cece95c
Showing 6 changed files with 30 additions and 18 deletions.
4 changes: 1 addition & 3 deletions repoqa/provider/base.py
@@ -5,12 +5,10 @@
 from abc import ABC, abstractmethod
 from typing import List
 
-SYSTEM_MSG = "You are a helpful assistant good at code understanding."
-
 
 class BaseProvider(ABC):
     @abstractmethod
     def generate_reply(
-        self, question, n=1, max_tokens=1024, temperature=0
+        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
     ) -> List[str]:
         ...
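For orientation, a minimal sketch of what the updated interface asks of a concrete provider; the EchoProvider class is hypothetical and exists only to illustrate the new optional system_msg parameter:

from typing import List

from repoqa.provider.base import BaseProvider
from repoqa.provider.request import construct_message_list


class EchoProvider(BaseProvider):
    """Hypothetical provider used only to illustrate the new signature."""

    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
    ) -> List[str]:
        # A real backend would consume this message list; with system_msg=None
        # it contains only the user turn.
        messages = construct_message_list(question, system_msg)
        return [question] * n

Both providers touched in this commit (openai and vllm) now override generate_reply with this signature, so the caller decides whether a system turn is sent.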
6 changes: 3 additions & 3 deletions repoqa/provider/openai.py
@@ -7,7 +7,7 @@
 
 from openai import Client
 
-from repoqa.provider.base import SYSTEM_MSG, BaseProvider
+from repoqa.provider.base import BaseProvider
 from repoqa.provider.request.openai import make_auto_request
 
 
@@ -19,7 +19,7 @@ def __init__(self, model, base_url: str = None):
         )
 
     def generate_reply(
-        self, question, n=1, max_tokens=1024, temperature=0
+        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
     ) -> List[str]:
         assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
         replies = make_auto_request(
@@ -29,7 +29,7 @@ def generate_reply(
             temperature=temperature,
             n=n,
             max_tokens=max_tokens,
-            system_msg=SYSTEM_MSG,
+            system_msg=system_msg,
         )
 
         return [reply.message.content for reply in replies.choices]
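A hedged call-site sketch; the OpenAIProvider class name, the model name, and the API-key handling are assumptions not shown in this hunk:

import os

from repoqa.provider.openai import OpenAIProvider  # class name assumed

os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # placeholder key
provider = OpenAIProvider("gpt-4-turbo")  # model name is illustrative

# Explicit system message, forwarded to make_auto_request as system_msg.
with_system = provider.generate_reply(
    "What does the needle function return?",
    system_msg="You are a helpful assistant good at code understanding.",
)

# Omitting it leaves system_msg=None, so no system turn is sent at all.
without_system = provider.generate_reply("What does the needle function return?")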
7 changes: 7 additions & 0 deletions repoqa/provider/request/__init__.py
@@ -1,3 +1,10 @@
 # SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
 #
 # SPDX-License-Identifier: Apache-2.0
+
+
+def construct_message_list(message, system_message=None):
+    msglist = [{"role": "user", "content": message}]
+    if system_message:
+        msglist.insert(0, {"role": "system", "content": system_message})
+    return msglist
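This new helper is what makes the system message optional end to end. A quick sketch of its behavior (the question strings are made up):

from repoqa.provider.request import construct_message_list

# With a system message, a system turn is prepended to the user turn.
print(construct_message_list(
    "What does foo() do?",
    system_message="You are a helpful assistant good at code understanding.",
))
# -> [{'role': 'system', 'content': '...'}, {'role': 'user', 'content': 'What does foo() do?'}]

# With system_message=None (the Mixtral case), only the user turn remains.
print(construct_message_list("What does foo() do?"))
# -> [{'role': 'user', 'content': 'What does foo() do?'}]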
7 changes: 3 additions & 4 deletions repoqa/provider/request/openai.py
@@ -8,6 +8,8 @@
 import openai
 from openai.types.chat import ChatCompletion
 
+from repoqa.provider.request import construct_message_list
+
 
 def make_request(
     client: openai.Client,
@@ -21,10 +23,7 @@ def make_request(
 ) -> ChatCompletion:
     return client.chat.completions.create(
         model=model,
-        messages=[
-            {"role": "system", "content": system_msg},
-            {"role": "user", "content": message},
-        ],
+        messages=construct_message_list(message, system_message=system_msg),
         max_tokens=max_tokens,
         temperature=temperature,
         n=n,
11 changes: 4 additions & 7 deletions repoqa/provider/vllm.py
@@ -7,7 +7,8 @@
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 
-from repoqa.provider.base import SYSTEM_MSG, BaseProvider
+from repoqa.provider.base import BaseProvider
+from repoqa.provider.request import construct_message_list
 
 
 class VllmProvider(BaseProvider):
@@ -20,15 +21,11 @@ def __init__(self, model, tensor_parallel_size, max_model_len):
         )
 
     def generate_reply(
-        self, question, n=1, max_tokens=1024, temperature=0
+        self, question, n=1, max_tokens=1024, temperature=0, system_msg=None
     ) -> List[str]:
         assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
         prompt = self.tokenizer.apply_chat_template(
-            [
-                {"role": "system", "content": SYSTEM_MSG},
-                {"role": "user", "content": question},
-            ],
-            tokenize=False,
+            construct_message_list(question, system_msg), tokenize=False
         )
         vllm_outputs = self.llm.generate(
             [prompt],
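Why this matters for Mixtral: at the time of this commit its chat template accepts only alternating user/assistant turns, so the system turn has to be dropped before apply_chat_template is called. A hedged sketch, assuming hub access and that the template still behaves this way:

from transformers import AutoTokenizer

from repoqa.provider.request import construct_message_list

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

# With system_msg=None the message list is user-only, which the Mixtral
# template accepts; including a system turn would make it raise instead.
prompt = tokenizer.apply_chat_template(
    construct_message_list("Describe the needle function.", None),
    tokenize=False,
)
print(prompt)  # roughly "<s>[INST] Describe the needle function. [/INST]"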
13 changes: 12 additions & 1 deletion repoqa/search_needle_function.py
@@ -163,6 +163,7 @@ def evaluate_model(
     result_dir: str = "results",
     languages: List[str] = None,
     caching: bool = False,  # if enabled, will cache the tasks which can be used to resume
+    system_message: str = "You are a helpful assistant good at code understanding.",
 ):
     if backend is None:
         if base_url is not None:
@@ -302,6 +303,14 @@ def evaluate_model(
             max_model_len=int(code_context_size * 1.25),  # Magic number
         )
 
+    if model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+        system_message = None
+        print(f"Warning: {model} does not support system message")
+
+    if not system_message:
+        print("🔥 System message is disabled")
+        system_message = None
+
     with open(result_file, "a") as f_out:
         with progress(f"Running {model}") as pbar:
             for task in pbar.track(tasks):
@@ -316,7 +325,9 @@
                 for key in task["template"].split("\n"):
                     prompt += task[key]
 
-                replies = engine.generate_reply(prompt, n=1, max_tokens=max_new_tokens)
+                replies = engine.generate_reply(
+                    prompt, n=1, max_tokens=max_new_tokens, system_msg=system_message
+                )
                 result = {**task, "output": replies}
                 f_out.write(json.dumps(result) + "\n")
                 results.append(result)
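Putting it together, a hedged sketch of how the new system_message knob might be used when calling evaluate_model directly; the model and backend values are illustrative, and other arguments are left at their defaults:

from repoqa.search_needle_function import evaluate_model

# Override the default system message for an OpenAI-backed run.
evaluate_model(
    model="gpt-4-turbo",
    backend="openai",
    system_message="You are a careful code-reading assistant.",
)

# An empty string is falsy, so this hits the "System message is disabled" branch;
# Mixtral is forced to system_message=None regardless of what is passed.
evaluate_model(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    backend="vllm",
    system_message="",
)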
