Merge pull request #374 from chidochipotle/sk/pg-vector
RAG MVP
snopoke authored May 28, 2024
2 parents f441adb + 22522a5 commit 3eeae5b
Showing 5 changed files with 81 additions and 3 deletions.
43 changes: 42 additions & 1 deletion apps/experiments/tasks.py
@@ -3,15 +3,18 @@
 
 from celery.app import shared_task
 from langchain.schema import AIMessage, HumanMessage
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
 from taskbadger.celery import Task as TaskbadgerTask
 
 from apps.channels.datamodels import WebMessage
 from apps.chat.bots import create_conversation
 from apps.chat.channels import WebChannel
-from apps.experiments.models import ExperimentSession, PromptBuilderHistory, SourceMaterial
+from apps.experiments.models import Experiment, ExperimentSession, PromptBuilderHistory, SourceMaterial
 from apps.service_providers.models import LlmProvider
 from apps.users.models import CustomUser
 from apps.utils.taskbadger import update_taskbadger_data
+from apps.vectordb.vectorstore import PGVector
 
 
 @shared_task(bind=True, base=TaskbadgerTask)
@@ -23,6 +26,44 @@ def get_response_for_webchat_task(self, experiment_session_id: int, message_text
     return message_handler.new_user_message(message)
 
 
+@shared_task(bind=True, base=TaskbadgerTask)
+def store_rag_embedding(self, experiment_id: int) -> None:
+    experiment = Experiment.objects.get(id=experiment_id)
+    file_path = experiment.files.all().last().file.path
+    splits = load_rag_file(file_path)
+    embeddings_model = experiment.get_llm_service().get_openai_embeddings()
+    PGVector.from_texts(splits, embeddings_model, None, experiment)
+
+
+def load_rag_file(file_path: str) -> list[str]:
+    """
+    Loads a file of a supported type (PDF or plain text) into Langchain.
+
+    Args:
+        file_path (str): The path to the file.
+
+    Returns:
+        str_splits: A list of strings holding the page_content of the
+            first ten chunked Langchain Document objects.
+    """
+    # Detect the loader from the file extension
+    extension = file_path.split(".")[-1].lower()
+    if extension == "pdf":
+        loader = PyMuPDFLoader(file_path, extract_images=False)
+    elif extension in ("txt", "text"):
+        loader = TextLoader(file_path)
+    else:
+        raise ValueError(f"Unsupported file type: {extension}")
+
+    # Load the file with the chosen loader, then split it into chunks
+    documents = loader.load()
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+    splits = text_splitter.split_documents(documents)
+    # Note: only the first 10 chunks are kept for embedding
+    str_splits = [s.page_content for s in splits[0:10]]
+    return str_splits
+
+
 @shared_task
 def get_prompt_builder_response_task(team_id: int, user_id, data_dict: dict) -> dict[str, str | int]:
     llm_service = LlmProvider.objects.get(id=data_dict["provider"]).get_llm_service()
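For orientation: store_rag_embedding is registered as a Celery task, so callers would normally queue it rather than run it inline. A minimal usage sketch, assuming a saved experiment with at least one uploaded file (the .delay() call is Celery's standard shortcut for apply_async; it is not part of this diff):

    from apps.experiments.models import Experiment
    from apps.experiments.tasks import store_rag_embedding

    experiment = Experiment.objects.get(id=1)  # hypothetical experiment id
    # Chunk the most recently uploaded file and store its embeddings in
    # pgvector on a Celery worker, off the request cycle.
    store_rag_embedding.delay(experiment.id)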
7 changes: 6 additions & 1 deletion apps/experiments/views/experiment.py
@@ -53,7 +53,7 @@
 from apps.experiments.helpers import get_real_user_or_none
 from apps.experiments.models import Experiment, ExperimentSession, Participant, SessionStatus, SyntheticVoice
 from apps.experiments.tables import ExperimentSessionsTable, ExperimentTable
-from apps.experiments.tasks import get_response_for_webchat_task
+from apps.experiments.tasks import get_response_for_webchat_task, store_rag_embedding
 from apps.experiments.views.prompt import PROMPT_DATA_SESSION_KEY
 from apps.files.forms import get_file_formset
 from apps.files.views import BaseAddFileHtmxView, BaseDeleteFileView
@@ -220,6 +220,10 @@ def _validate_prompt_variables(form_data):
     available_variables = set()
     if form_data.get("source_material"):
         available_variables.add("source_material")
+    # TODO: these two variables should only be offered after a db request
+    # confirms that RAG files have been uploaded for the experiment
+    available_variables.add("context")
+    available_variables.add("input")
     missing_vars = required_variables - available_variables
     known_vars = {"source_material"}
     if missing_vars:
@@ -361,6 +365,7 @@ def form_valid(self, form):
         experiment = get_object_or_404(Experiment, team=self.request.team, pk=self.kwargs["pk"])
         file = super().form_valid(form)
         experiment.files.add(file)
+        store_rag_embedding(experiment.id)
         return file
 
     def get_delete_url(self, file):
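Two follow-ups stand out in this file. First, form_valid calls store_rag_embedding synchronously, so the embedding work blocks the upload request; dispatching it with store_rag_embedding.delay(experiment.id) would hand it to a Celery worker instead. Second, a sketch of the db check that the TODO comment in _validate_prompt_variables asks for (the helper name is illustrative, not part of the diff):

    def rag_prompt_variables(experiment) -> set[str]:
        # Offer the RAG prompt variables only when the experiment
        # actually has files uploaded for retrieval.
        if experiment.files.exists():
            return {"context", "input"}
        return set()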
9 changes: 9 additions & 0 deletions apps/service_providers/llm_service/main.py
@@ -7,6 +7,7 @@
 from langchain_community.chat_models import ChatAnthropic
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.language_models import BaseLanguageModel
+from langchain_openai import OpenAIEmbeddings
 from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
 from openai import OpenAI
 from openai._base_client import SyncAPIClient
@@ -42,6 +43,14 @@ class OpenAILlmService(LlmService):
     openai_api_base: str = None
     openai_organization: str = None
 
+    def get_openai_embeddings(self, model="text-embedding-3-small") -> OpenAIEmbeddings:
+        return OpenAIEmbeddings(
+            openai_api_key=self.openai_api_key,
+            openai_api_base=self.openai_api_base,
+            openai_organization=self.openai_organization,
+            model=model,
+        )
+
     def get_raw_client(self) -> OpenAI:
         return OpenAI(api_key=self.openai_api_key, organization=self.openai_organization, base_url=self.openai_api_base)
 
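A short usage sketch for the new embeddings helper, assuming an OpenAILlmService instance named llm_service (embed_query is the standard LangChain embeddings API; the dimension check reflects the model's documented default):

    embeddings = llm_service.get_openai_embeddings()  # defaults to text-embedding-3-small
    vector = embeddings.embed_query("What is pgvector?")
    # text-embedding-3-small produces 1536-dimensional vectors by default
    assert len(vector) == 1536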
23 changes: 22 additions & 1 deletion apps/service_providers/llm_service/runnables.py
@@ -29,6 +29,7 @@
 from apps.chat.conversation import compress_chat_history
 from apps.chat.models import ChatMessage, ChatMessageType
 from apps.experiments.models import Experiment, ExperimentSession
+from apps.vectordb.vectorstore import PGVector
 
 logger = logging.getLogger(__name__)
 
@@ -51,7 +52,8 @@ def create_experiment_runnable(experiment: Experiment, session: ExperimentSession
     assert experiment.llm_provider, "Experiment must have an LLM provider"
     if experiment.tools_enabled:
         return AgentExperimentRunnable(experiment=experiment, session=session)
-
+    if experiment.files.exists():
+        return RagExperimentRunnable(experiment=experiment, session=session)
     return SimpleExperimentRunnable(experiment=experiment, session=session)
 
 
@@ -219,6 +221,25 @@ def _build_chain(self) -> Runnable[dict[str, Any], str]:
         )
 
 
+class RagExperimentRunnable(ExperimentRunnable):
+    def _build_chain(self) -> Runnable[dict[str, Any], str]:
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        model = self.llm_service.get_chat_model(self.experiment.llm, self.experiment.temperature)
+        embeddings = self.experiment.get_llm_service().get_openai_embeddings()
+        retriever = PGVector(self.experiment, embeddings).as_retriever()
+        return (
+            {"context": retriever | format_docs, "input": RunnablePassthrough()}
+            | RunnablePassthrough.assign(
+                history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history")
+            )
+            | self.prompt
+            | model
+            | StrOutputParser()
+        )
+
+
 class AgentExperimentRunnable(ExperimentRunnable):
     def _parse_output(self, output):
         return output.get("output", "")
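Stripped of the history handling, the chain built in RagExperimentRunnable._build_chain follows the standard LCEL retrieval pattern. A minimal sketch, assuming retriever, prompt (a template with "context" and "input" variables), and model are already constructed:

    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.runnables import RunnablePassthrough

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # The dict fans the incoming question out to two prompt variables:
    # "context" is filled by retrieval, "input" passes the question through.
    chain = (
        {"context": retriever | format_docs, "input": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    answer = chain.invoke("What does the uploaded document say?")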
2 changes: 2 additions & 0 deletions requirements/requirements.in
@@ -45,3 +45,5 @@ turn-python>=0.2.0
 jinja2
 django-taggit
 pgvector
+PyMuPDF
+langchainhub
