Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
fixed context error
Browse files Browse the repository at this point in the history
  • Loading branch information
filopedraz committed Nov 8, 2023
1 parent 0ee18f1 commit 0543e13
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 20 deletions.
6 changes: 4 additions & 2 deletions cht-llama-cpp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

load_dotenv()

MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'mistral-7b-instruct-v0.1.Q5_0')}.gguf"
MODEL_PATH = f"./ml/models/{os.getenv('MODEL_ID', 'yarn-mistral-7b-128k.Q4_K_M')}.gguf"
# Mistral gguf follows ChatML syntax
# https://github.com/openai/openai-python/blob/main/chatml.md
PROMPT_TEMPLATE_STRING = '{"system_prompt_template": "<|im_start|>system\\n{}\\n<|im_end|>\\n", "default_system_text": "You are an helpful AI assistant.", "user_prompt_template": "<|im_start|>user\\n{}\\n<|im_end|>\\n", "assistant_prompt_template": "<|im_start|>assistant\\n{}\\n<|im_end|>\\n", "request_assistant_response_token": "<|im_start|>assistant\\n", "template_format": "chatml"}' # noqa
Expand All @@ -19,8 +19,10 @@
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", help="Path to GGUF", default=MODEL_PATH)
parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
parser.add_argument("--ctx", help="Context dimension", type=int, default=4096)
args = parser.parse_args()
MODEL_PATH = args.model_path
MODEL_CTX = args.ctx

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
Expand All @@ -33,7 +35,7 @@ def create_start_app_handler(app: FastAPI):
def start_app() -> None:
from models import LLaMACPPBasedModel

LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING)
LLaMACPPBasedModel.get_model(MODEL_PATH, PROMPT_TEMPLATE_STRING, MODEL_CTX)

return start_app

Expand Down
19 changes: 2 additions & 17 deletions cht-llama-cpp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,6 @@ class LLaMACPPBasedModel(object):
def tokenize(cls, prompt):
return cls.model.tokenize(b" " + prompt.encode("utf-8"))

@classmethod
def reduce_number_of_messages(cls, messages, max_tokens):
buffer_tokens = 32
ctx_max_tokens = 4096
num_messages = len(messages)

tokens = [len(cls.tokenize(doc["content"])) for doc in messages]

token_count = sum(tokens[:num_messages])
while token_count + max_tokens + buffer_tokens > ctx_max_tokens:
num_messages -= 1
token_count -= tokens[num_messages]
return messages[:num_messages]

@classmethod
def generate(
cls,
Expand All @@ -55,7 +41,6 @@ def generate(
):
if stop is None:
stop = []
messages = cls.reduce_number_of_messages(messages[::-1], max_tokens)[::-1]
cls.model.n_threads = n_threads
cht_resp = cls.model.create_chat_completion(
messages,
Expand All @@ -75,13 +60,13 @@ def generate(
return cht_resp

@classmethod
def get_model(cls, model_path, prompt_template_jsonstr):
def get_model(cls, model_path, prompt_template_jsonstr, n_ctx):
chat_format = "llama-2"
if "mistral" in model_path:
cls.PROMPT_TEMPLATE = json.loads(prompt_template_jsonstr)
chat_format = cls.PROMPT_TEMPLATE.get("template_format", "chatml")
if cls.model is None:
cls.model = Llama(model_path, chat_format=chat_format)
cls.model = Llama(model_path, chat_format=chat_format, n_ctx=n_ctx)

return cls.model

Expand Down
2 changes: 1 addition & 1 deletion cht-llama-cpp/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ tqdm==4.65.0
httpx==0.23.3
python-dotenv==1.0.0
tenacity==8.2.2
llama-cpp-python==0.2.11
llama-cpp-python==0.2.14

0 comments on commit 0543e13

Please sign in to comment.