This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Add instruction query generator #226

Merged: 8 commits, merged on Dec 20, 2023
2 changes: 1 addition & 1 deletion config/anyscale.yaml
@@ -27,7 +27,7 @@ chat_engine:
# Since Anyscale's LLM endpoint currently doesn't support function calling, we will use the LastMessageQueryGenerator
# --------------------------------------------------------------------
query_builder:
-    type: LastMessageQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
+    type: InstructionQueryGenerator # Options: [InstructionQueryGenerator, LastMessageQueryGenerator]

# -------------------------------------------------------------------------------------------------------------
# ContextEngine configuration
2 changes: 1 addition & 1 deletion config/config.yaml
@@ -59,7 +59,7 @@ chat_engine:
# The query builder is responsible for generating textual queries given user message history.
# --------------------------------------------------------------------
query_builder:
-    type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
+    type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator, InstructionQueryGenerator]
params:
prompt: *query_builder_prompt # The query builder's system prompt for calling the LLM
function_description: # A function description passed to the LLM's `function_calling` API
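For reference, a minimal sketch of what the `query_builder` section could look like when opting into the new generator. The `chat_engine` nesting and the inline comment are assumptions based on the two config files above; per the anyscale.yaml change, `InstructionQueryGenerator` is configured with `type` alone, and its constructor (below) takes only an optional `llm`:

```yaml
chat_engine:
  # ... other chat_engine settings ...
  query_builder:
    type: InstructionQueryGenerator   # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator, InstructionQueryGenerator]
    # No params required; the generator's optional `llm` argument defaults to OpenAILLM.
```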
1 change: 1 addition & 0 deletions src/canopy/chat_engine/query_generator/__init__.py
@@ -1,3 +1,4 @@
from .base import QueryGenerator
from .function_calling import FunctionCallingQueryGenerator
from .last_message import LastMessageQueryGenerator
+from .instruction import InstructionQueryGenerator
133 changes: 133 additions & 0 deletions src/canopy/chat_engine/query_generator/instruction.py
@@ -0,0 +1,133 @@
import logging
import re
from typing import List, Optional, cast

from tenacity import retry, stop_after_attempt, retry_if_exception_type

from canopy.chat_engine.models import HistoryPruningMethod
from canopy.chat_engine.prompt_builder import PromptBuilder
from canopy.chat_engine.query_generator import QueryGenerator, LastMessageQueryGenerator
from canopy.llm import BaseLLM, OpenAILLM
from canopy.models.api_models import ChatResponse
from canopy.models.data_models import Messages, Query, UserMessage

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """You are an expert on formulating a search query for a search engine,
to assist in responding to the user's question.

Given the following conversation, create a standalone question summarizing
the user's last question, in its original language.

Reply to me in JSON in this format:

{"question": {The question you generated here}}.

Example:

User: What is the weather today?

Expected Response:
```json
{"question": "What is the weather today?"}
```

Example 2:

User: How should I wash my white clothes in the laundry?
Assistant: Separate from the colorful ones, and use a bleach.
User: Which temperature?

Expected Response:
```json
{"question": "What is the right temperature for washing white clothes?"}
```


Do not try to answer the question; just try to formulate a question representing the user's question.
Do not return any other format other than the specified JSON format and keep it really short.

""" # noqa: E501

USER_PROMPT = "Return only a JSON containing a single key 'question' and the value."


class ExtractionException(ValueError):
pass


class InstructionQueryGenerator(QueryGenerator):
_DEFAULT_COMPONENTS = {
"llm": OpenAILLM,
}

def __init__(self,
*,
llm: Optional[BaseLLM] = None):
"""
This `QueryGenerator` uses an LLM to formulate a knowledge base query
from the full chat history. It does so by prompting the LLM to reply
        with a JSON object containing a single key, `question`, whose value is
        the query for the knowledge base. If the LLM response cannot be parsed
        (after multiple retries), it falls back to returning the last message
        from the history as the query, much like `LastMessageQueryGenerator`.
        """
self._llm = llm or self._DEFAULT_COMPONENTS["llm"]()
self._system_prompt = SYSTEM_PROMPT
self._prompt_builder = PromptBuilder(HistoryPruningMethod.RAISE, 2)
self._last_message_query_generator = LastMessageQueryGenerator()

# Define a regex pattern to find the JSON object with the key "question"
self._question_regex = re.compile(r'{\s*"question":\s*"([^"]+)"\s*}')

def generate(self,
messages: Messages,
max_prompt_tokens: int) -> List[Query]:

        # Append a user message that instructs the LLM to return only the JSON object.
new_history = (
messages +
[
UserMessage(content=USER_PROMPT)
]
)

new_messages = self._prompt_builder.build(system_prompt=self._system_prompt,
history=new_history,
max_tokens=max_prompt_tokens)

question = self._try_generate_question(new_messages)

if question is None:
logger.warning("Falling back to the last message query generator.")
return self._last_message_query_generator.generate(messages, 0)
else:
return [Query(text=question)]

def _get_answer(self, messages: Messages) -> str:
llm_response = self._llm.chat_completion(messages)
response = cast(ChatResponse, llm_response)
return response.choices[0].message.content

@retry(stop=stop_after_attempt(3),
retry=retry_if_exception_type(ExtractionException),
retry_error_callback=lambda _: None)
def _try_generate_question(self, messages: Messages) -> Optional[str]:
content = self._get_answer(messages)
return self._extract_question(content)

def _extract_question(self, text: str) -> str:

# Search for the pattern in the text
match = re.search(self._question_regex, text)

# If a match is found, extract and return the first occurrence
if match:
return match.group(1)

raise ExtractionException("Failed to extract the question.")

async def agenerate(self,
messages: Messages,
max_prompt_tokens: int) -> List[Query]:
raise NotImplementedError
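As a usage note, here is a minimal sketch of calling the new generator directly. It assumes the default `OpenAILLM` is usable (i.e. an OpenAI API key is configured); the message classes and the `generate` signature are taken from the file above:

```python
from canopy.chat_engine.query_generator import InstructionQueryGenerator
from canopy.models.data_models import UserMessage, AssistantMessage

# Uses the default OpenAILLM; pass `llm=` to supply any other BaseLLM implementation.
query_gen = InstructionQueryGenerator()

history = [
    UserMessage(content="How can I init a client?"),
    AssistantMessage(content="Which kind of client?"),
    UserMessage(content="A pinecone client."),
]

# Returns a single-element list of Query objects; if the LLM's JSON reply cannot
# be parsed after 3 attempts, the last user message is returned as the query.
queries = query_gen.generate(messages=history, max_prompt_tokens=4096)
print(queries[0].text)
```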
88 changes: 88 additions & 0 deletions tests/unit/query_generators/test_instruction_query_generator.py
@@ -0,0 +1,88 @@
from typing import List
from unittest.mock import create_autospec

import pytest

from canopy.chat_engine.query_generator import InstructionQueryGenerator
from canopy.llm import BaseLLM
from canopy.models.api_models import ChatResponse, _Choice, TokenCounts
from canopy.models.data_models import Query, UserMessage, AssistantMessage


@pytest.fixture
def mock_llm():
return create_autospec(BaseLLM)


@pytest.fixture
def query_generator(mock_llm):
query_gen = InstructionQueryGenerator(
llm=mock_llm,
)
return query_gen


@pytest.fixture
def sample_messages():
return [UserMessage(content="How can I init a client?"),
AssistantMessage(content="Which kind of client?"),
UserMessage(content="A pinecone client.")]


@pytest.mark.parametrize(("response", "query", "call_count"), [
(
'{"question": "How do I init a pinecone client?"}',
"How do I init a pinecone client?",
1
),

(
'Unparseable JSON response from LLM, falling back to the last message',
"A pinecone client.",
3
)

])
def test_generate(query_generator,
mock_llm,
sample_messages,
response,
query,
call_count):
mock_llm.chat_completion.return_value = ChatResponse(
id="meta-llama/Llama-2-7b-chat-hf-HTQ-4",
object="text_completion",
created=1702569324,
model='meta-llama/Llama-2-7b-chat-hf',
usage=TokenCounts(
prompt_tokens=367,
completion_tokens=19,
total_tokens=386
),
choices=[
_Choice(
index=0,
message=AssistantMessage(
content=response
)
)
]
)

result = query_generator.generate(messages=sample_messages,
max_prompt_tokens=4096)

assert mock_llm.chat_completion.call_count == call_count
assert isinstance(result, List)
assert len(result) == 1
assert result[0] == Query(text=query)


@pytest.mark.asyncio
async def test_agenerate_not_implemented(query_generator,
mock_llm,
sample_messages
):
with pytest.raises(NotImplementedError):
await query_generator.agenerate(messages=sample_messages,
max_prompt_tokens=100)