add chat memory
baptiste-pasquier committed Mar 22, 2024
1 parent fa67d5f commit b0b445d
Showing 14 changed files with 298 additions and 9 deletions.
39 changes: 36 additions & 3 deletions backend/rag_1/chain.py
@@ -4,9 +4,13 @@
 from langchain_core.messages import BaseMessage, HumanMessage
 from langchain_core.output_parsers.string import StrOutputParser
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
-from langchain_core.runnables.base import RunnableSequence
+from langchain_core.runnables.base import RunnableSequence, RunnableSerializable
 from omegaconf.dictconfig import DictConfig
+from pydantic import BaseModel

+from backend.rag_components.chain_links.rag_with_history import (
+    construct_rag_with_history,
+)
 from backend.utils.image import resize_base64_image
 from backend.utils.llm import get_vision_llm
 from backend.utils.retriever import get_retriever
@@ -83,7 +87,19 @@ def img_prompt_func(data_dict: dict) -> list[BaseMessage]:
     return [HumanMessage(content=messages)]


-def get_chain(config: DictConfig) -> RunnableSequence:
+class Question(BaseModel):
+    """Question to be answered."""
+
+    question: str
+
+
+class Response(BaseModel):
+    """Response to the question."""
+
+    response: str
+
+
+def get_base_chain(config: DictConfig) -> RunnableSequence:
     """Constructs a RAG pipeline that retrieves image and text data from documents.

     The pipeline consists of the following steps:
@@ -112,5 +128,22 @@ def get_chain(config: DictConfig) -> RunnableSequence:
         | model
         | StrOutputParser()
     )
-
-    return chain
+    typed_chain = chain.with_types(input_type=str, output_type=Response)
+
+    return typed_chain
+
+
+def get_chain(config: DictConfig) -> RunnableSerializable:
+    """Get the appropriate RAG pipeline based on the configuration.
+
+    Args:
+        config (DictConfig): Configuration object.
+
+    Returns:
+        RunnableSerializable: RAG pipeline.
+    """
+    base_chain = get_base_chain(config)
+    if config.rag.enable_chat_memory:
+        chain_with_mem = construct_rag_with_history(base_chain, config)
+        return chain_with_mem
+    return base_chain
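
To make the new toggle concrete, here is a minimal usage sketch. It assumes the YAML below is loaded with OmegaConf, that DATABASE_URL and ENABLE_AUTHENTICATION are set in the environment, and uses a hypothetical question and session id; when rag.enable_chat_memory is true, get_chain returns a RunnableWithMessageHistory, which needs a session_id in its call config.

from omegaconf import OmegaConf

from backend.rag_1.chain import get_chain

# Minimal sketch: load the config shown below; DATABASE_URL and
# ENABLE_AUTHENTICATION are assumed to be set in the environment.
config = OmegaConf.load("backend/rag_1/config.yaml")
chain = get_chain(config)

if config.rag.enable_chat_memory:
    # The memory-wrapped chain needs a session_id in its call config.
    answer = chain.invoke(
        {"question": "What does the report say about margins?"},  # hypothetical
        config={"configurable": {"session_id": "demo-session"}},
    )
else:
    # The base chain is typed to take the question string directly.
    answer = chain.invoke("What does the report say about margins?")
print(answer)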
10 changes: 10 additions & 0 deletions backend/rag_1/config.py
@@ -56,6 +56,14 @@ def validate_size(cls, value: list[float]) -> list[float]:
         return value


+@dataclass(config=ConfigDict(extra="forbid"))
+class RagConfig:
+    """Configuration for RAG."""
+
+    database_url: str
+    enable_chat_memory: bool
+
+
 @dataclass(config=ConfigDict(extra="forbid"))
 class Config:
     """Configuration for the RAG Option 1."""
@@ -71,6 +79,8 @@ class Config:

     ingest: IngestConfig

+    rag: RagConfig
+

 def validate_config(config: DictConfig) -> Config:
     """Validate the configuration.
4 changes: 4 additions & 0 deletions backend/rag_1/config.yaml
@@ -52,3 +52,7 @@ ingest:
   table_min_size: [0.0, 0.0]

   export_extracted: True
+
+rag:
+  database_url: ${oc.env:DATABASE_URL}
+  enable_chat_memory: ${oc.decode:${oc.env:ENABLE_AUTHENTICATION}}
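
For readers unfamiliar with OmegaConf's oc.env and oc.decode resolvers, the self-contained sketch below shows how these two entries resolve; the environment values are hypothetical. Note that, as committed, enable_chat_memory is driven by the ENABLE_AUTHENTICATION variable.

import os

from omegaconf import OmegaConf

# Hypothetical environment, for illustration only.
os.environ["DATABASE_URL"] = "sqlite:///chat_history.db"
os.environ["ENABLE_AUTHENTICATION"] = "true"

cfg = OmegaConf.create(
    """
rag:
  database_url: ${oc.env:DATABASE_URL}
  enable_chat_memory: ${oc.decode:${oc.env:ENABLE_AUTHENTICATION}}
"""
)

# oc.env reads the variable as a string; oc.decode parses "true" into True.
assert cfg.rag.database_url == "sqlite:///chat_history.db"
assert cfg.rag.enable_chat_memory is True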
42 changes: 39 additions & 3 deletions backend/rag_2/chain.py
@@ -3,16 +3,35 @@
 from langchain_core.output_parsers.string import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnablePassthrough
-from langchain_core.runnables.base import RunnableSequence
+from langchain_core.runnables.base import (
+    RunnableSequence,
+    RunnableSerializable,
+)
 from omegaconf.dictconfig import DictConfig
+from pydantic import BaseModel

+from backend.rag_components.chain_links.rag_with_history import (
+    construct_rag_with_history,
+)
 from backend.utils.llm import get_text_llm
 from backend.utils.retriever import get_retriever

 from . import prompts


-def get_chain(config: DictConfig) -> RunnableSequence:
+class Question(BaseModel):
+    """Question to be answered."""
+
+    question: str
+
+
+class Response(BaseModel):
+    """Response to the question."""
+
+    response: str
+
+
+def get_base_chain(config: DictConfig) -> RunnableSequence:
     """Constructs a RAG pipeline that retrieves text data from documents.

     The pipeline consists of the following steps:
@@ -43,5 +62,22 @@ def get_chain(config: DictConfig) -> RunnableSequence:
         | model
         | StrOutputParser()
     )
-
-    return chain
+    typed_chain = chain.with_types(input_type=str, output_type=Response)
+
+    return typed_chain
+
+
+def get_chain(config: DictConfig) -> RunnableSerializable:
+    """Get the appropriate RAG pipeline based on the configuration.
+
+    Args:
+        config (DictConfig): Configuration object.
+
+    Returns:
+        RunnableSerializable: RAG pipeline.
+    """
+    base_chain = get_base_chain(config)
+    if config.rag.enable_chat_memory:
+        chain_with_mem = construct_rag_with_history(base_chain, config)
+        return chain_with_mem
+    return base_chain
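
One detail worth noting for all three options: .with_types only attaches schema metadata to the runnable (useful, e.g., for LangServe-generated docs); it does not validate or coerce values at runtime. A small sketch, independent of this repo:

from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel


class Response(BaseModel):
    """Toy response schema, mirroring the one in the diff."""

    response: str


shout = RunnableLambda(lambda s: s.upper()).with_types(
    input_type=str, output_type=Response
)

print(shout.invoke("hello"))  # "HELLO" -- runtime behavior is unchanged
print(shout.get_output_schema())  # schema metadata now points at Response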
10 changes: 10 additions & 0 deletions backend/rag_2/config.py
@@ -130,6 +130,14 @@ def validate_size(cls, value: list[float]) -> list[float]:
         return value


+@dataclass(config=ConfigDict(extra="forbid"))
+class RagConfig:
+    """Configuration for RAG."""
+
+    database_url: str
+    enable_chat_memory: bool
+
+
 @dataclass(config=ConfigDict(extra="forbid"))
 class Config:
     """Configuration for the RAG Option 2."""
@@ -147,6 +155,8 @@ class Config:

     ingest: IngestConfig

+    rag: RagConfig
+

 def validate_config(config: DictConfig) -> Config:
     """Validate the configuration.
4 changes: 4 additions & 0 deletions backend/rag_2/config.yaml
@@ -81,3 +81,7 @@ ingest:
   image: "summary"

   export_extracted: True
+
+rag:
+  database_url: ${oc.env:DATABASE_URL}
+  enable_chat_memory: ${oc.decode:${oc.env:ENABLE_AUTHENTICATION}}
42 changes: 39 additions & 3 deletions backend/rag_3/chain.py
@@ -4,9 +4,16 @@
 from langchain_core.messages import BaseMessage, HumanMessage
 from langchain_core.output_parsers.string import StrOutputParser
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
-from langchain_core.runnables.base import RunnableSequence
+from langchain_core.runnables.base import (
+    RunnableSequence,
+    RunnableSerializable,
+)
 from omegaconf.dictconfig import DictConfig
+from pydantic import BaseModel

+from backend.rag_components.chain_links.rag_with_history import (
+    construct_rag_with_history,
+)
 from backend.utils.image import resize_base64_image
 from backend.utils.llm import get_vision_llm
 from backend.utils.retriever import get_retriever
@@ -83,7 +90,19 @@ def img_prompt_func(data_dict: dict) -> list[BaseMessage]:
     return [HumanMessage(content=messages)]


-def get_chain(config: DictConfig) -> RunnableSequence:
+class Question(BaseModel):
+    """Question to be answered."""
+
+    question: str
+
+
+class Response(BaseModel):
+    """Response to the question."""
+
+    response: str
+
+
+def get_base_chain(config: DictConfig) -> RunnableSequence:
     """Constructs a RAG pipeline that retrieves image and text data from documents.

     The pipeline consists of the following steps:
@@ -112,5 +131,22 @@ def get_chain(config: DictConfig) -> RunnableSequence:
         | model
         | StrOutputParser()
     )
-
-    return chain
+    typed_chain = chain.with_types(input_type=str, output_type=Response)
+
+    return typed_chain
+
+
+def get_chain(config: DictConfig) -> RunnableSerializable:
+    """Get the appropriate RAG pipeline based on the configuration.
+
+    Args:
+        config (DictConfig): Configuration object.
+
+    Returns:
+        RunnableSerializable: RAG pipeline.
+    """
+    base_chain = get_base_chain(config)
+    if config.rag.enable_chat_memory:
+        chain_with_mem = construct_rag_with_history(base_chain, config)
+        return chain_with_mem
+    return base_chain
10 changes: 10 additions & 0 deletions backend/rag_3/config.py
@@ -119,6 +119,14 @@ def validate_size(cls, value: list[float]) -> list[float]:
         return value


+@dataclass(config=ConfigDict(extra="forbid"))
+class RagConfig:
+    """Configuration for RAG."""
+
+    database_url: str
+    enable_chat_memory: bool
+
+
 @dataclass(config=ConfigDict(extra="forbid"))
 class Config:
     """Configuration for the RAG Option 3."""
@@ -136,6 +144,8 @@ class Config:

     ingest: IngestConfig

+    rag: RagConfig
+

 def validate_config(config: DictConfig) -> Config:
     """Validate the configuration.
4 changes: 4 additions & 0 deletions backend/rag_3/config.yaml
@@ -81,3 +81,7 @@ ingest:
   image: "content"

   export_extracted: True
+
+rag:
+  database_url: ${oc.env:DATABASE_URL}
+  enable_chat_memory: ${oc.decode:${oc.env:ENABLE_AUTHENTICATION}}
Empty file.
47 changes: 47 additions & 0 deletions backend/rag_components/chain_links/condense_question.py
@@ -0,0 +1,47 @@
"""This chain condenses the chat history and the question into one standalone question."""

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from pydantic import BaseModel


class QuestionWithChatHistory(BaseModel):
    """Question together with the chat history that preceded it."""

    question: str
    chat_history: str


class StandaloneQuestion(BaseModel):
    """Self-sufficient rephrasing of the question."""

    standalone_question: str


prompt = """\
Given the conversation history and the following question, rephrase the user's \
question in its original language so that it is self-sufficient. You are presented \
with a conversation that may contain some spelling mistakes and grammatical errors, \
but your goal is to understand the underlying question. Make sure to avoid the use of \
unclear pronouns.

If the question is already self-sufficient, return the original question. If it seems \
the user is authorizing the chatbot to answer without specific context, make sure to \
reflect that in the rephrased question.

Chat history: {chat_history}

Question: {question}
"""  # noqa: E501


def condense_question(llm: BaseLanguageModel) -> RunnableSequence:
    """Builds a chain that rephrases the question using the chat history."""
    condense_question_prompt = PromptTemplate.from_template(
        prompt
    )  # expects chat_history and question

    standalone_question = condense_question_prompt | llm | StrOutputParser()

    typed_chain = standalone_question.with_types(
        input_type=QuestionWithChatHistory, output_type=StandaloneQuestion
    )
    return typed_chain
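
A usage sketch for this chain; it assumes an OpenAI chat model via langchain_openai purely for illustration (in this repo the model comes from get_text_llm), and the messages are hypothetical.

from langchain_openai import ChatOpenAI  # stand-in; any chat model works

from backend.rag_components.chain_links.condense_question import condense_question

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # hypothetical model choice
chain = condense_question(llm)

standalone = chain.invoke(
    {
        "chat_history": "Human: Who wrote the 2023 annual report?\nAI: The finance team.",
        "question": "When did they publish it?",
    }
)
# Expected shape of the result: something like
# "When did the finance team publish the 2023 annual report?"
print(standalone)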
55 changes: 55 additions & 0 deletions backend/rag_components/chain_links/rag_with_history.py
@@ -0,0 +1,55 @@
"""This chain adds chat memory on top of a base RAG chain."""

from langchain_core.runnables.base import RunnableSequence
from langchain_core.runnables.history import RunnableWithMessageHistory
from pydantic import BaseModel

from backend.config import RagConfig
from backend.rag_components.chain_links.condense_question import condense_question
from backend.rag_components.chat_message_history import get_chat_message_history
from backend.utils.llm import get_text_llm


class QuestionWithHistory(BaseModel):
    """Question with chat history."""

    question: str
    chat_history: str


class Response(BaseModel):
    """Response to the question."""

    response: str


def construct_rag_with_history(
    base_chain: RunnableSequence,
    config: RagConfig,
) -> RunnableWithMessageHistory:
    """Constructs a RAG pipeline with memory.

    Args:
        base_chain (RunnableSequence): Base RAG pipeline to wrap.
        config (RagConfig): Configuration object.

    Returns:
        RunnableWithMessageHistory: RAG pipeline with memory.
    """
    text_llm = get_text_llm(config)

    # First rephrase the question into a standalone one, then run the base RAG chain.
    reformulate_question = condense_question(text_llm)

    chain = reformulate_question | base_chain
    typed_chain = chain.with_types(input_type=QuestionWithHistory, output_type=Response)

    chain_with_mem = RunnableWithMessageHistory(
        typed_chain,
        lambda session_id: get_chat_message_history(config, session_id),
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    return chain_with_mem
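
To see the RunnableWithMessageHistory wiring in isolation, here is a self-contained sketch with an in-memory store standing in for get_chat_message_history (which, per the configs above, is backed by DATABASE_URL); the echo runnable and session id are hypothetical.

from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

stores: dict[str, InMemoryChatMessageHistory] = {}


def get_history(session_id: str) -> InMemoryChatMessageHistory:
    # Stand-in for get_chat_message_history(config, session_id).
    return stores.setdefault(session_id, InMemoryChatMessageHistory())


echo = RunnableLambda(lambda x: f"You said: {x['question']}")

chain = RunnableWithMessageHistory(
    echo,
    get_history,
    input_messages_key="question",
    history_messages_key="chat_history",
)

session = {"configurable": {"session_id": "s1"}}
print(chain.invoke({"question": "hi"}, config=session))
# Both the human turn and the answer are now recorded for session "s1".
print(stores["s1"].messages)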