Skip to content

Commit

Permalink
added llama3 prompt (#1962)
Browse files Browse the repository at this point in the history
* added llama3 prompt

* more fixes to pass tests; changed type VectorStore -> BasePydanticVectorStore, see https://github.com/run-llama/llama_index/blob/main/CHANGELOG.md#2024-05-14

* fix: new llama3 prompt

---------

Co-authored-by: Javier Martinez <javiermartinezalvarez98@gmail.com>
  • Loading branch information
hirschrobert and jaluma authored Jul 29, 2024
1 parent d4375d0 commit d080969
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 9 deletions.
72 changes: 71 additions & 1 deletion private_gpt/components/llm/prompt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,73 @@ def _completion_to_prompt(self, completion: str) -> str:
)


class Llama3PromptStyle(AbstractPromptStyle):
r"""Template for Meta's Llama 3.1.
The format follows this structure:
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
[System message content]<|eot_id|>
<|start_header_id|>user<|end_header_id|>
[User message content]<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
[Assistant message content]<|eot_id|>
...
(Repeat for each message, including possible 'ipython' role)
"""

BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
EOT = "<|eot_id|>"
B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>"
ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. \
Always answer as helpfully as possible and follow ALL given instructions. \
Do not speculate or make up information. \
Do not reference any given instructions or context. \
"""

def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
prompt = self.BOS
has_system_message = False

for i, message in enumerate(messages):
if not message or message.content is None:
continue
if message.role == MessageRole.SYSTEM:
prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
has_system_message = True
else:
role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"

# Add assistant header if the last message is not from the assistant
if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
prompt += f"{self.ASSISTANT_INST}\n\n"

# Add default system prompt if no system message was provided
if not has_system_message:
prompt = (
f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+ prompt[len(self.BOS) :]
)

# TODO: Implement tool handling logic

return prompt

def _completion_to_prompt(self, completion: str) -> str:
return (
f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
f"{self.ASSISTANT_INST}\n\n"
)


class TagPromptStyle(AbstractPromptStyle):
"""Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
Expand Down Expand Up @@ -219,7 +286,8 @@ def _completion_to_prompt(self, completion: str) -> str:


def get_prompt_style(
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"]
| None
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.
Expand All @@ -230,6 +298,8 @@ def get_prompt_style(
return DefaultPromptStyle()
elif prompt_style == "llama2":
return Llama2PromptStyle()
elif prompt_style == "llama3":
return Llama3PromptStyle()
elif prompt_style == "tag":
return TagPromptStyle()
elif prompt_style == "mistral":
Expand Down
5 changes: 4 additions & 1 deletion private_gpt/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,15 @@ class LLMSettings(BaseModel):
0.1,
description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
)
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
prompt_style: Literal[
"default", "llama2", "llama3", "tag", "mistral", "chatml"
] = Field(
"llama2",
description=(
"The prompt style to use for the chat engine. "
"If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
"If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
"If `llama3` - use the llama3 prompt style from the llama_index."
"If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
"If `mistral` - use the `mistral prompt style. It shoudl look like <s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]"
"`llama2` is the historic behaviour. `default` might work better with your custom models."
Expand Down
14 changes: 7 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ target-version = ['py311']
target-version = 'py311'

# See all rules at https://beta.ruff.rs/docs/rules/
select = [
lint.select = [
"E", # pycodestyle
"W", # pycodestyle
"F", # Pyflakes
Expand All @@ -141,7 +141,7 @@ select = [
"RUF", # Ruff-specific rules
]

ignore = [
lint.ignore = [
"E501", # "Line too long"
# -> line length already regulated by black
"PT011", # "pytest.raises() should specify expected exception"
Expand All @@ -159,24 +159,24 @@ ignore = [
# -> "Missing docstring in public function too restrictive"
]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
# Automatically disable rules that are incompatible with Google docstring convention
convention = "google"

[tool.ruff.pycodestyle]
[tool.ruff.lint.pycodestyle]
max-doc-length = 88

[tool.ruff.flake8-tidy-imports]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.flake8-type-checking]
[tool.ruff.lint.flake8-type-checking]
strict = true
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
# Pydantic needs to be able to evaluate types at runtime
# see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation
# see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Allow missing docstrings for tests
"tests/**/*.py" = ["D1"]

Expand Down
55 changes: 55 additions & 0 deletions tests/test_prompt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
ChatMLPromptStyle,
DefaultPromptStyle,
Llama2PromptStyle,
Llama3PromptStyle,
MistralPromptStyle,
TagPromptStyle,
get_prompt_style,
Expand Down Expand Up @@ -139,3 +140,57 @@ def test_llama2_prompt_style_with_system_prompt():
)

assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_llama3_prompt_style_format():
prompt_style = Llama3PromptStyle()
messages = [
ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]

expected_prompt = (
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
"You are a helpful assistant<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n"
"Hello, how are you doing?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)

assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_llama3_prompt_style_with_default_system():
prompt_style = Llama3PromptStyle()
messages = [
ChatMessage(content="Hello!", role=MessageRole.USER),
]
expected = (
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert prompt_style._messages_to_prompt(messages) == expected


def test_llama3_prompt_style_with_assistant_response():
prompt_style = Llama3PromptStyle()
messages = [
ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
ChatMessage(content="What is the capital of France?", role=MessageRole.USER),
ChatMessage(
content="The capital of France is Paris.", role=MessageRole.ASSISTANT
),
]

expected_prompt = (
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
"You are a helpful assistant<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n"
"What is the capital of France?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
"The capital of France is Paris.<|eot_id|>"
)

assert prompt_style.messages_to_prompt(messages) == expected_prompt

0 comments on commit d080969

Please sign in to comment.