This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Merge pull request #124 from biswaroop1547/feat/chatml-prompt-template
casperdcl authored Oct 26, 2023
2 parents e675886 + 44c2a04 commit c4b1dab
Showing 5 changed files with 94 additions and 8 deletions.
2 changes: 1 addition & 1 deletion cht-llama-cpp/build-aarch64-apple-darwin.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 set -e
 
-export VERSION=1.1.1
+export VERSION=1.1.2
 
 test -f venv/bin/activate || python -m venv venv
 source venv/bin/activate
2 changes: 1 addition & 1 deletion cht-llama-cpp/build.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 set -e
-export VERSION=1.1.0
+export VERSION=1.1.2
 source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"
 
 build_cpu ghcr.io/premai-io/chat-mistral-7b-instruct-q5 mistral-7b-instruct-v0.1.Q5_0 --build-arg="MODEL_ID=mistral-7b-instruct-v0.1.Q5_0" --build-arg="MODEL_DOWNLOAD_URL=https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q5_0.gguf" ${@:1}
6 changes: 5 additions & 1 deletion cht-llama-cpp/main.py
@@ -11,6 +11,10 @@
 load_dotenv()
 
 MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'mistral-7b-instruct-v0.1.Q5_0')}.gguf"
+# Mistral gguf follows ChatML syntax
+# https://github.com/openai/openai-python/blob/main/chatml.md
+PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}' # noqa
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH)
@@ -29,7 +33,7 @@ def create_start_app_handler(app: FastAPI):
     def start_app() -> None:
         from models import LLaMACPPBasedModel
 
-        LLaMACPPBasedModel.get_model(MODEL_PATH)
+        LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING)
 
     return start_app
 
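Aside (not part of the commit): the PROMPT_TEMPLATE_STRING added above is a JSON-encoded ChatML template. Below is a minimal sketch of how it parses, repeating the string verbatim from the hunk so that only the Python standard library is assumed:

import json

# Same JSON string that main.py now passes to LLaMACPPBasedModel.get_model().
PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}' # noqa

template = json.loads(PROMPT_TEMPLATE_STRING)
print(sorted(template))             # the six keys that models.py reads
print(template["template_format"])  # chatml

# Each *_prompt_template has one "{}" slot that str.format() fills with message content.
print(template["user_prompt_template"].format("Why should we run ML models on premise?"))
# <|im_start|>user
# Why should we run ML models on premise?
# <|im_end|>

models.py (next file) consumes these keys in stitch_prompt() and uses template_format to pick the chat format when the model is loaded.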
58 changes: 54 additions & 4 deletions cht-llama-cpp/models.py
@@ -1,12 +1,26 @@
+import json
 import multiprocessing
+from typing import Any, Dict, List
 
-from llama_cpp import Llama
+from llama_cpp import Llama, llama_chat_format, llama_types
 
 DEFAULT_N_THREADS = max(multiprocessing.cpu_count() // 2, 1)
 
 
+@llama_chat_format.register_chat_format("chatml")
+def initiate_chatml_prompt_template(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> llama_chat_format.ChatFormatterResponse:
+    # TODO: drop when https://github.com/abetlen/llama-cpp-python/issues/717 supports ChatML
+
+    _prompt = LLaMACPPBasedModel.stitch_prompt(messages, LLaMACPPBasedModel.PROMPT_TEMPLATE)
+    return llama_chat_format.ChatFormatterResponse(prompt=_prompt)
+
+
 class LLaMACPPBasedModel(object):
     model = None
+    PROMPT_TEMPLATE = {}
 
     @classmethod
     def tokenize(cls, prompt):
@@ -43,22 +57,58 @@ def generate(
         stop = []
         messages = cls.reduce_number_of_messages(messages[::-1], max_tokens)[::-1]
         cls.model.n_threads = n_threads
-        return cls.model.create_chat_completion(
+        cht_resp = cls.model.create_chat_completion(
             messages,
             temperature=temperature,
             top_p=top_p,
             stream=stream,
             stop=stop,
             max_tokens=max_tokens,
         )
+        if not stream and cls.PROMPT_TEMPLATE.get("template_format") == "chatml":
+            cht_resp["choices"][0]["message"]["content"] = (
+                cht_resp["choices"][0]["message"]["content"].split("\n<|im_end|>")[0].strip()
+            )
+
+        # TODO: handle postprocessing for streaming responses
+
+        return cht_resp
 
     @classmethod
-    def get_model(cls, model_path):
+    def get_model(cls, model_path, prompt_template_jsonstr):
+        chat_format = "llama-2"
+        if "mistral" in model_path:
+            cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr)
+            chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml")
         if cls.model is None:
-            cls.model = Llama(model_path)
+            cls.model = Llama(model_path, chat_format=chat_format)
 
         return cls.model
 
     @classmethod
     def embeddings(cls, text):
         return cls.model.create_embedding(text)
+
+    @staticmethod
+    def stitch_prompt(messages: list, prompt_template: Dict[str, str]) -> str:
+        system_prompt_template = prompt_template["system_prompt_template"]
+        default_system_text = prompt_template["default_system_text"]
+        user_prompt_template = prompt_template["user_prompt_template"]
+        assistant_prompt_template = prompt_template["assistant_prompt_template"]
+        request_assistant_response_token = prompt_template.get("request_assistant_response_token", "")
+
+        system_prompt, chat_prompt = "", ""
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+            if role == "system":
+                system_prompt = system_prompt_template.format(content)
+            elif role == "user":
+                chat_prompt += user_prompt_template.format(content)
+            elif role == "assistant":
+                chat_prompt += assistant_prompt_template.format(content)
+
+        if not system_prompt:
+            system_prompt = system_prompt_template.format(default_system_text)
+
+        return system_prompt + chat_prompt + request_assistant_response_token
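
Aside (not part of the commit): a standalone sketch of the ChatML post-processing that generate() now applies to non-streaming responses. The raw_content value is made up for illustration; the trimming expression is the same one used in the hunk above, which drops everything from the first trailing "\n<|im_end|>" onward so the client only sees the assistant text:

# Hypothetical raw completion that still carries ChatML control tokens.
raw_content = "On-premise deployment keeps data inside your own infrastructure.\n<|im_end|>\n<|im_start|>assistant\n"
cleaned = raw_content.split("\n<|im_end|>")[0].strip()
print(cleaned)  # On-premise deployment keeps data inside your own infrastructure.

Streaming responses are left untouched for now, matching the TODO in the diff.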
34 changes: 33 additions & 1 deletion cht-llama-cpp/tests/test_views.py
@@ -1,5 +1,8 @@
+import json
+
 from fastapi.testclient import TestClient
-from main import get_application
+from main import PROMPT_TEMPLATE_STRING, get_application
+from models import LLaMACPPBasedModel
 
 
 def test_chat_llama_cpp() -> None:
@@ -24,3 +27,32 @@ def test_chat_llama_cpp() -> None:
         },
     )
     assert response.status_code == 200
+
+
+def test_chatml_stitch_prompt():
+    messages = [
+        {"role": "user", "content": "Why should we run ML models on premise?"},
+        {
+            "role": "assistant",
+            "content": "There are several reasons why an organization might choose to run machine learning (ML) models on-premise:\n\n1. Security and privacy concerns: Running ML models on-premise allows organizations to", # noqa
+        },
+    ]
+    prompt_template = json.loads(PROMPT_TEMPLATE_STRING)
+    assert prompt_template["template_format"] == "chatml"
+    result = LLaMACPPBasedModel.stitch_prompt(messages, prompt_template=prompt_template)
+    assert (
+        result
+        == """<|im_start|>system
+You are an helpful AI assistant.
+<|im_end|>
+<|im_start|>user
+Why should we run ML models on premise?
+<|im_end|>
+<|im_start|>assistant
+There are several reasons why an organization might choose to run machine learning (ML) models on-premise:
+
+1. Security and privacy concerns: Running ML models on-premise allows organizations to
+<|im_end|>
+<|im_start|>assistant
+"""
+    )
